project-llm-trainer 0.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +683 -0
- llm_trainer/checkpoint.py +126 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +297 -0
- llm_trainer/ds_checkpoint.py +63 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +450 -0
- llm_trainer/grpo_trainer.py +385 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +268 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +521 -0
- llm_trainer/scheduler.py +179 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +324 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +547 -0
- project_llm_trainer-0.12.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.12.3.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.12.3.data/scripts/ds_train +17 -0
- project_llm_trainer-0.12.3.data/scripts/plot_log +69 -0
- project_llm_trainer-0.12.3.data/scripts/plot_lr +45 -0
- project_llm_trainer-0.12.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.12.3.data/scripts/smart_train +37 -0
- project_llm_trainer-0.12.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.12.3.dist-info/RECORD +32 -0
- project_llm_trainer-0.12.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.12.3.dist-info/top_level.txt +1 -0
llm_trainer/loss.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
from typing import List, Optional
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LMLoss(nn.Module):
    """
    Causal language-model cross-entropy loss with optional per-token weighting.

    Logits at position t are scored against the label at position t + 1 (the
    usual next-token shift). Positions whose label equals ``ignore_index`` do
    not contribute.

    :param ignore_index: label value excluded from the loss.
    :param critical_tokens: token ids whose loss is scaled by ``critical_alpha``.
        They are passed to ``F.cross_entropy`` as class weights, so the mean is
        taken over the summed weights of the target tokens.
    :param critical_alpha: weight applied to ``critical_tokens``; every other
        token has weight 1.0.
    :param vocab_size: length of the weight vector. When it is 0 (the default)
        but ``critical_tokens`` is given, the weights are built on the fly from
        the vocabulary dimension of the logits passed to ``forward``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        *,
        critical_tokens: Optional[List[int]] = None,
        critical_alpha: float = 1.0,
        vocab_size: int = 0
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.critical_tokens = critical_tokens
        self.critical_alpha = critical_alpha

        if critical_tokens and vocab_size > 0:
            weights = torch.ones(vocab_size)
            # Give the critical tokens their own weight; all others keep 1.0.
            weights[critical_tokens] = critical_alpha
            self.register_buffer('weights', weights)
        else:
            # A None buffer keeps `self.weights` defined (so forward never hits an
            # AttributeError) without adding an entry to the state_dict.
            self.register_buffer('weights', None)

    def _class_weights(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the per-class weight vector for ``F.cross_entropy``, or None when unweighted."""
        if not self.critical_tokens:
            return None

        if self.weights is not None:
            return self.weights.to(logits.device, dtype=logits.dtype)

        # vocab_size was not given at construction time: derive it from the logits.
        weights = torch.ones(logits.shape[-1], device=logits.device, dtype=logits.dtype)
        weights[self.critical_tokens] = self.critical_alpha
        return weights

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        :param logits: shape (batch, seq_len, vocab_size)
        :param labels: shape (batch, seq_len)
        :return: scalar mean cross-entropy over non-ignored, shifted positions.
        """
        vocab_size = logits.shape[-1]

        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].reshape(-1, vocab_size)
        shift_labels = labels[..., 1:].reshape(-1)

        return F.cross_entropy(
            shift_logits,
            shift_labels,
            ignore_index=self.ignore_index,
            weight=self._class_weights(logits)
        )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class KDLoss(nn.Module):
    """
    Language Model Knowledge Distillation Loss
    https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py#L266

    Computes the per-position cross-entropy between the teacher's distribution
    and the student's log-distribution, ``-sum_v p_teacher(v) * log p_student(v)``,
    averaged over positions whose label is not ``ignore_index``.
    """

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        :param logits: student logits, vocabulary on the last dimension.
        :param teacher_logits: teacher logits, same shape as ``logits``.
        :param labels: used only for masking. No next-token shift is applied here,
            so it is assumed to be already aligned with the logits positions
            (presumably (batch, seq_len) vs (batch, seq_len, vocab) — confirm at call sites).
        :return: scalar distillation loss; 0 when every label is ignored.
        """
        teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        # Entries where the student logit is +-inf (e.g. masked vocab) would yield
        # 0 * -inf = NaN, so they are zeroed out explicitly.
        inf_mask = torch.isinf(logits)

        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)

        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (labels != self.ignore_index).int().view(-1)

        # clamp(min=1): a batch with no valid position yields 0 instead of 0/0 = NaN.
        distil_loss = -torch.sum(x * mask, dim=0) / torch.sum(mask, dim=0).clamp(min=1)

        return distil_loss
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class DPOLoss(nn.Module):
    """
    Direct Preference Optimization loss, with optional conservative-DPO label
    smoothing or the IPO objective.

    :param beta: inverse temperature applied to the reward margin.
    :param label_smoothing: cDPO smoothing factor; 0 gives the original DPO loss.
    :param ipo: use the squared IPO loss instead of the logistic DPO loss.
    """

    def __init__(
        self,
        beta: float,
        label_smoothing: float = 0.0,
        ipo: bool = False
    ):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_reject_logps: torch.Tensor,
        ref_chosen_logps: torch.Tensor,
        ref_reject_logps: torch.Tensor
    ) -> torch.Tensor:
        """Return the batch-mean preference loss."""
        # How much more the policy prefers chosen over rejected than the reference does.
        policy_gap = policy_chosen_logps - policy_reject_logps
        reference_gap = ref_chosen_logps - ref_reject_logps
        margin = policy_gap - reference_gap

        if not self.ipo:
            # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO
            # (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            eps = self.label_smoothing
            scaled = self.beta * margin
            per_pair = -(1 - eps) * F.logsigmoid(scaled) - eps * F.logsigmoid(-scaled)
        else:
            # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
            per_pair = (margin - 1 / (2 * self.beta)) ** 2

        return per_pair.mean()
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class PPOLoss(nn.Module):
    """
    PPO (Proximal Policy Optimization) loss.

    Computes the clipped-surrogate actor loss and the clipped value loss together.
    """

    def __init__(
        self,
        clip_eps: float,
        vf_coef: float,
        value_clip_eps: Optional[float] = None,
    ):
        """
        :param clip_eps: epsilon of the PPO ratio clipping range.
        :param vf_coef: coefficient applied to the value loss.
        :param value_clip_eps: clipping range of new values around ``old_values``.
            Defaults to ``clip_eps``, matching the previous behavior.
        """
        super().__init__()
        self.clip_eps = clip_eps
        self.vf_coef = vf_coef
        self.value_clip_eps = clip_eps if value_clip_eps is None else value_clip_eps

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        values: torch.Tensor,
        old_values: torch.Tensor,
        returns: torch.Tensor,
        advantages: torch.Tensor,
        mask: torch.Tensor
    ):
        """
        Compute the total, actor and value losses plus two diagnostics.

        :param log_probs: log probabilities under the current policy, shape [batch_size, seq_len]
        :param old_log_probs: log probabilities under the rollout-time policy, shape [batch_size, seq_len]
        :param values: values from the current critic, shape [batch_size, seq_len]
        :param old_values: rollout-time values, shape [batch_size, seq_len]
        :param returns: GAE returns, shape [batch_size, seq_len]
        :param advantages: GAE advantages, shape [batch_size, seq_len]
        :param mask: 1 for positions that contribute (generated tokens), shape [batch_size, seq_len]
        :return: (total_loss, actor_loss, value_loss, approx_kl, clip_frac). The last two
            are computed without gradient.
        """
        # Number of contributing positions; at least 1 so an empty mask cannot divide by zero.
        denom = mask.sum().clamp(min=1.0)

        # Value loss: pessimistic max of the unclipped and clipped squared errors.
        values_clipped = old_values + torch.clamp(
            values - old_values, -self.value_clip_eps, self.value_clip_eps
        )
        vf_loss_unclipped = F.mse_loss(values, returns, reduction='none')
        vf_loss_clipped = F.mse_loss(values_clipped, returns, reduction='none')
        value_loss = torch.max(vf_loss_unclipped, vf_loss_clipped)
        value_loss = 0.5 * (value_loss * mask).sum() / denom
        value_loss = value_loss * self.vf_coef

        # Actor loss: probability ratio r_t = exp(log_prob_new - log_prob_old).
        log_ratio = log_probs - old_log_probs
        ratio = torch.exp(log_ratio)

        # Clipped surrogate objective. Take the smaller of the two and negate it,
        # since the objective is maximized by minimizing its negative.
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
        actor_loss = -torch.sum(torch.min(surr1, surr2) * mask) / denom

        total_loss = actor_loss + value_loss

        with torch.no_grad():
            # Approximate KL divergence: E[(r - 1) - log r].
            approx_kl = torch.sum(((ratio - 1) - log_ratio) * mask) / denom

            # Fraction of positions whose ratio fell outside the clip range.
            clipped = ratio.gt(1.0 + self.clip_eps) | ratio.lt(1.0 - self.clip_eps)
            clip_frac = torch.sum(clipped.float() * mask) / denom

        return total_loss, actor_loss, value_loss, approx_kl, clip_frac
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class GRPOLoss(nn.Module):
    """
    GRPO-family policy loss.

    :param beta: weight of the per-token KL penalty against the reference policy (0 disables it).
    :param clip_eps_low: lower clip range, ratio >= 1 - clip_eps_low.
    :param clip_eps_high: upper clip range, ratio <= 1 + clip_eps_high. Any falsy
        value (None or 0) falls back to ``clip_eps_low``.
    :param delta: if set, the unclipped ratio is additionally capped at ``delta``
        (two-sided clipping).
    :param importance_sampling_level: ``'seq'`` uses one length-normalized ratio per
        sequence (GSPO); any other value uses per-token ratios (GRPO).
    :param loss_type: ``'bnpo'`` averages over all valid tokens in the batch;
        ``'dr_grpo'`` divides by batch_size * gen_max_new_tokens; anything else
        averages per sequence, then over the batch.
    :param gen_max_new_tokens: required (non-zero) when ``loss_type == 'dr_grpo'``.
    :raises ValueError: if ``loss_type == 'dr_grpo'`` and ``gen_max_new_tokens`` is missing or 0.
    """

    def __init__(
        self,
        beta: float,
        clip_eps_low: float,
        clip_eps_high: Optional[float] = None,
        delta: Optional[float] = None,
        importance_sampling_level: str = 'token',
        loss_type: str = 'grpo',
        gen_max_new_tokens: Optional[int] = None
    ):
        super().__init__()

        # Validate up front (instead of an `assert` in forward, which -O strips) so a
        # misconfiguration fails at construction rather than on the first step.
        if loss_type == 'dr_grpo' and not gen_max_new_tokens:
            raise ValueError("loss_type='dr_grpo' requires a non-zero gen_max_new_tokens")

        self.beta = beta
        self.clip_eps_low = clip_eps_low
        self.clip_eps_high = clip_eps_high if clip_eps_high else clip_eps_low
        self.delta = delta
        self.importance_sampling_level = importance_sampling_level
        self.loss_type = loss_type
        self.gen_max_new_tokens = gen_max_new_tokens

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        ref_log_probs: torch.Tensor,
        completion_mask: torch.Tensor,
        advantages: torch.Tensor
    ) -> torch.Tensor:
        """
        :param log_probs: current-policy token log probs, (batch, seq_len) — sums over -1 assume this.
        :param old_log_probs: rollout-time token log probs, same shape.
        :param ref_log_probs: reference-policy token log probs, same shape (used only if beta != 0).
        :param completion_mask: 1 for completion tokens that contribute, same shape.
        :param advantages: must broadcast against (batch, seq_len), e.g. (batch, 1).
        :return: scalar loss.
        """
        log_ratio = log_probs - old_log_probs
        if self.importance_sampling_level == "seq":
            # GSPO: one length-normalized log ratio per sequence, broadcast over its tokens.
            seq_log_ratio = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
            log_importance_weights = seq_log_ratio.unsqueeze(-1)
        else:
            # GRPO: per-token ratios.
            log_importance_weights = log_ratio

        coef_1 = torch.exp(log_importance_weights)
        coef_2 = torch.clamp(coef_1, 1 - self.clip_eps_low, 1 + self.clip_eps_high)

        # Two-sided clipping: also cap the unclipped ratio.
        if self.delta is not None:
            coef_1 = torch.clamp(coef_1, max=self.delta)

        per_token_loss = -torch.min(coef_1 * advantages, coef_2 * advantages)

        if self.beta != 0.0:
            # k3 KL estimator: exp(d) - d - 1 with d = log p_ref - log p.
            ref_diff = ref_log_probs - log_probs
            per_token_kl = torch.exp(ref_diff) - ref_diff - 1
            per_token_loss = per_token_loss + self.beta * per_token_kl

        masked_loss = per_token_loss * completion_mask
        if self.loss_type == "bnpo":
            return masked_loss.sum() / completion_mask.sum().clamp(min=1.0)
        if self.loss_type == "dr_grpo":
            return masked_loss.sum() / (per_token_loss.size(0) * self.gen_max_new_tokens)
        return (masked_loss.sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
|
llm_trainer/parallel.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional, Tuple
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
from torch.utils.data import Dataset, DataLoader
|
|
9
|
+
from torch.utils.data.distributed import DistributedSampler
|
|
10
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import deepspeed
|
|
14
|
+
except: ...
|
|
15
|
+
|
|
16
|
+
from .log import Logger
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Parallel(ABC):
    """
    Base class for a device-placement / distributed-training strategy.

    Owns rank and device discovery, distributed data-loader sampling and simple
    collective helpers. Subclasses implement :meth:`process` to wrap the model.
    """

    def __init__(
        self,
        _init_process_group: bool = True,
        _use_parallel: bool = True
    ):
        self._initialize(_init_process_group, _use_parallel)

    def _initialize(
        self,
        _init_process_group: bool,
        _use_parallel: bool
    ):
        """
        Read RANK / LOCAL_RANK from the environment and choose the device.

        Parallel mode is on only when requested and RANK is set. In that case the
        NCCL process group is optionally initialized and this process is pinned to
        cuda:LOCAL_RANK. Otherwise the first available of cuda / mps / cpu is used.
        """
        self._global_rank: int = int(os.environ.get('RANK', -1))
        self._local_rank: int = int(os.environ.get('LOCAL_RANK', -1))
        self._use_parallel: bool = _use_parallel and self._global_rank != -1

        self._sampler: Optional[DistributedSampler] = None
        self.model: Optional[nn.Module] = None

        # Enabling TF32 matmuls is best-effort: unsupported builds/devices are ignored on purpose.
        try:
            torch.set_float32_matmul_precision('high')
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass

        if self._use_parallel:
            if _init_process_group:
                dist.init_process_group(backend='nccl')

            self.device: str = f'cuda:{self._local_rank}'
            self.device_type: str = 'cuda'
            torch.cuda.set_device(self.device)

            Logger.std_log(f'global_rank={self._global_rank}, local_rank={self._local_rank}, world_size={self.world_size}')
        else:
            device = "cpu"
            if torch.cuda.is_available():
                device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                device = "mps"

            self.device: str = device
            self.device_type: str = device

    @abstractmethod
    def process(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        kwargs: Optional[dict] = None,
        save_instance: bool = True
    ) -> Tuple[nn.Module, torch.optim.Optimizer]: ...

    def process_dataloader(
        self,
        dataset: Dataset,
        data_loader_kwargs: dict,
        sampler_kwargs: Optional[dict]=None
    ) -> DataLoader:
        """
        Build a DataLoader, using a DistributedSampler when running in parallel.

        :param dataset:
        :param data_loader_kwargs: DataLoader options, e.g.
            "batch_size" int,
            "pin_memory" bool,
            "collate_fn" collate_fn,
            "num_workers" int,
            "shuffle" bool (ignored in parallel mode: the sampler controls shuffling),
            "drop_last" bool
        :param sampler_kwargs: DistributedSampler options (may be None), e.g.
            "shuffle" bool,
            "drop_last" bool
        :return: the DataLoader. The caller's dicts are not modified.
        """
        if not self._use_parallel:
            return DataLoader(dataset=dataset, **data_loader_kwargs)

        self._sampler = DistributedSampler(dataset=dataset, **(sampler_kwargs or {}))
        # DataLoader rejects shuffle=True together with a sampler; shuffling is the sampler's job here.
        loader_kwargs = {key: value for key, value in data_loader_kwargs.items() if key != 'shuffle'}
        return DataLoader(dataset=dataset, sampler=self._sampler, **loader_kwargs)

    def on_epoch_start(self, epoch):
        # Compare with None explicitly: a sampler's truthiness is its __len__, which can be 0.
        if self._sampler is not None:
            self._sampler.set_epoch(epoch)

    def on_epoch_end(self, epoch): ...

    def synchronize(self):
        if self._use_parallel:
            torch.cuda.synchronize(device=self.device)

    def destroy(self):
        if self._use_parallel:
            dist.destroy_process_group()

    @property
    def parallel_train(self) -> bool:
        return self._use_parallel

    @property
    def is_main_process(self) -> bool:
        if self._use_parallel:
            return self._global_rank == 0

        return True

    @property
    def world_size(self) -> int:
        if self._use_parallel:
            return dist.get_world_size()
        return 1

    def wait(self, msg=None):
        """Barrier across all ranks (no-op for a single process), logging before and after."""
        if self.world_size == 1:
            return

        msg = f' for {msg}' if msg else ''
        Logger.std_log(f'wait at {self.device}{msg}')
        dist.barrier()
        Logger.std_log(f'continue at {self.device}{msg}')
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class DsParallel(Parallel):
    """Parallel strategy backed by DeepSpeed; DeepSpeed itself creates the process group."""

    def __init__(self):
        # The NCCL group comes from DeepSpeed, so the base class must not create another.
        deepspeed.init_distributed(dist_backend='nccl')
        super().__init__(_init_process_group=False)

    def process(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        kwargs: Optional[dict] = None,
        save_instance: bool = True
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """
        Wrap ``model``/``optimizer`` in a DeepSpeed engine.

        :param model: the bare model.
        :param optimizer: the optimizer to hand to DeepSpeed.
        :param kwargs: DeepSpeed configuration, passed as ``config_params``.
        :param save_instance: also keep the engine on ``self.model``.
        :return: (engine, deepspeed optimizer).
        """
        engine, ds_optimizer, _dataloader, _lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            dist_init_required=False,
            config_params=kwargs
        )

        if save_instance:
            self.model = engine

        return engine, ds_optimizer

    def synchronize(self): ...

    def destroy(self): ...
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class DdpParallel(Parallel):
    """
    Parallel strategy using PyTorch DistributedDataParallel.

    Outside a distributed launch it leaves the model unwrapped on the local device.
    """

    def __init__(self):
        super().__init__()

    def process(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        kwargs: Optional[dict] = None,
        save_instance: bool = True
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Move ``model`` to this rank's device and wrap it in DDP when parallel; ``kwargs`` is unused."""
        model.to(self.device)

        # self.model = DDP(module=model, broadcast_buffers=False, find_unused_parameters=True)
        wrapped = (
            DDP(module=model, device_ids=[self._local_rank], output_device=self._local_rank)
            if self._use_parallel
            else model
        )

        if save_instance:
            self.model = wrapped

        return wrapped, optimizer
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class NoneParallel(Parallel):
    """Single-process strategy: parallelism is forced off and the model is only moved to the device."""

    def __init__(self):
        super().__init__(_use_parallel=False)

    def process(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        kwargs: Optional[dict] = None,
        save_instance: bool = True
    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
        """Place ``model`` on ``self.device`` and return it unchanged with ``optimizer``; ``kwargs`` is unused."""
        model.to(self.device)
        if save_instance:
            self.model = model
        return model, optimizer
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
import itertools
|
|
4
|
+
from packaging import version
|
|
5
|
+
from torch import nn
|
|
6
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
7
|
+
import torch.distributed as dist
|
|
8
|
+
|
|
9
|
+
from .tools import TrainerTools
|
|
10
|
+
from .parallel import DsParallel, DdpParallel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@contextmanager
def unwrap_model_for_generation(model: nn.Module):
    """
    Context manager that yields the bare module of a possibly wrapped model for generation.

    - DeepSpeed ZeRO-3: parameters are gathered for the duration of the block. The ZeRO-3
      hooks are removed on entry and re-registered on exit, even if the block raises.
    - Other DeepSpeed stages and DDP: yields the wrapped module (see ``unwrap_model``).
    - Otherwise: yields ``model`` unchanged.

    Args:
        model:
            Model to be unwrapped.
    Yields:
        Unwrapped model.

    Example:
    ```python
    with unwrap_model_for_generation(model) as unwrapped_model:
        generated_outputs = unwrapped_model.generate(input_ids)
    ```
    """
    parallel = TrainerTools().parallel

    if isinstance(parallel, DsParallel):
        import deepspeed
        assert isinstance(model, deepspeed.DeepSpeedEngine)

        if model.zero_optimization_stage() == 3:
            with deepspeed.zero.GatheredParameters(model.parameters()):
                _remove_hooks(model)
                try:
                    yield unwrap_model(model)
                finally:
                    # Restore the hooks even when the caller's block raised; otherwise
                    # later training steps would run on a model without them.
                    _add_hooks(model)
        else:
            yield unwrap_model(model)
    elif isinstance(parallel, DdpParallel):
        yield unwrap_model(model)
    else:
        yield model
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
    """
    Copy (``mixup_alpha == 1.0``) or blend the parameters of ``_from`` into ``_to``.

    Must be called on every rank; ranks other than rank 0 may pass ``_to=None``.
    Not suitable when ``_to`` is a ZeRO-3 model.

    When ``mixup_alpha != 1.0``, every parameter of ``_to`` whose name appears in the
    source state dict becomes ``(1 - mixup_alpha) * target + mixup_alpha * source``.
    Buffers are not blended in that mode.
    """
    parallel = TrainerTools().parallel

    if isinstance(parallel, DsParallel):
        # _get_ds_model_params broadcasts with a collective, so every rank must make the same
        # only_rank0 decision. Deciding from the local `_to` alone deadlocks when rank 0 has a
        # target and the other ranks pass None. Only non-main ranks ever need the broadcast
        # (rank 0 already holds the gathered dict), so agree across ranks on whether any does.
        only_rank0 = True
        if parallel.world_size > 1:
            import torch
            needs_params = torch.tensor(
                [int(_to is not None and not parallel.is_main_process)],
                device=parallel.device
            )
            dist.all_reduce(needs_params, op=dist.ReduceOp.MAX)
            only_rank0 = needs_params.item() == 0
        state_dict = _get_ds_model_params(_from, only_rank0=only_rank0)
    elif isinstance(_from, DDP):
        state_dict = _from.module.state_dict()
    else:
        state_dict = _from.state_dict()

    if _to is None or not state_dict:
        return

    target_model = unwrap_model(_to)
    if mixup_alpha == 1.0:
        # Straight overwrite.
        target_model.load_state_dict(state_dict, strict=False)
        return

    # Blend parameters in place.
    for param_name, target_param in target_model.named_parameters():
        source = state_dict.get(param_name)
        if source is None:
            continue
        target_param.data.mul_(1.0 - mixup_alpha).add_(
            source.data.to(target_param.device),
            alpha=mixup_alpha
        )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def unwrap_model(model) -> nn.Module:
    """
    Return the module wrapped by a DeepSpeed engine or by DDP, or ``model`` itself otherwise.
    """
    try:
        import deepspeed
        if isinstance(model, deepspeed.DeepSpeedEngine):
            return model.module
    except Exception:
        # deepspeed is optional. If it is missing or fails to import, skip the DeepSpeed
        # check, but let KeyboardInterrupt/SystemExit propagate (a bare except swallowed them).
        pass

    if isinstance(model, DDP):
        return model.module

    return model
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _get_ds_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
    """
    Collect a CPU copy of the full state dict of a DeepSpeed engine on rank 0.

    Must be called on every rank: under ZeRO-3 the parameter gather is collective.
    Rank 0 receives the dict; every other rank receives None.
    """
    import deepspeed
    assert isinstance(model, deepspeed.DeepSpeedEngine)

    def _snapshot_to_cpu() -> dict:
        # Copy every entry to CPU so the result is detached from the live parameters.
        return {name: tensor.cpu().clone() for name, tensor in model.module.state_dict().items()}

    if model.zero_optimization_stage() == 3:
        # One GatheredParameters call covering all parameters. Inside the block,
        # rank 0's model.module holds full parameters, so state_dict() is complete.
        with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
            if not TrainerTools().parallel.is_main_process:
                return None
            return _snapshot_to_cpu()

    if not TrainerTools().parallel.is_main_process:
        return None
    return _snapshot_to_cpu()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _get_ds_model_params(model: nn.Module, only_rank0=False):
    """
    Extract the full state dict from a running DeepSpeedEngine (ZeRO stages 0-3).

    Tensors are CPU copies in the model's own dtype; no conversion to FP32 is done.
    Must be called on every rank.

    :param model: the DeepSpeedEngine to read from.
    :param only_rank0: when False and world_size > 1, rank 0's dict is broadcast so every
        rank returns a copy. When True, only rank 0 returns the dict and other ranks return
        None. The broadcast is a collective, so every rank must pass the same value.
    :return: the state dict, or None on ranks that did not receive it.
    """
    import deepspeed
    assert isinstance(model, deepspeed.DeepSpeedEngine)
    state_dict = _get_ds_full_state_dict_on_rank0(model)

    # At this point only rank 0 holds a dict; every other rank holds None.
    if not only_rank0 and TrainerTools().parallel.world_size > 1:
        # Rank 0 contributes the payload, the others a placeholder.
        object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
        # Blocking collective: synchronizes all ranks.
        dist.broadcast_object_list(object_list, src=0)
        state_dict = object_list[0]

    return state_dict
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _copy_params(model, target_model, mixup_alpha):
|
|
142
|
+
for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
|
|
143
|
+
target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
    """
    Copy/blend the parameters of a DeepSpeed engine into ``_to`` using ``_copy_params``.

    NOTE(review): ``_to`` is dereferenced unconditionally, so despite the Optional hint it
    must not be None here.
    NOTE(review): under ZeRO-3 only rank 0 copies, inside a GatheredParameters block with
    modifier_rank=0. This presumably relies on DeepSpeed propagating rank 0's edits for
    ZeRO-partitioned parameters. If ``_to`` is not ZeRO-partitioned, other ranks' copies
    may never be updated — confirm before relying on this on multiple ranks.
    """
    import deepspeed
    assert isinstance(_from, deepspeed.DeepSpeedEngine)

    source = unwrap_model(_from)

    if _from.zero_optimization_stage() != 3:
        _copy_params(source, _to, mixup_alpha)
        return

    gathered = list(source.parameters()) + list(_to.parameters())
    with deepspeed.zero.GatheredParameters(gathered, modifier_rank=0):
        if TrainerTools().parallel.is_main_process:
            _copy_params(source, _to, mixup_alpha)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _sync_ddp_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
|
|
162
|
+
assert isinstance(_from, DDP)
|
|
163
|
+
|
|
164
|
+
origin_from = unwrap_model(_from)
|
|
165
|
+
_copy_params(origin_from, _to, mixup_alpha)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _add_hooks(model: nn.Module) -> None:
    """
    Re-register the DeepSpeed ZeRO-3 module hooks on ``model`` (counterpart of ``_remove_hooks``).

    No-op if the engine has no ``optimizer`` attribute yet.
    :raises RuntimeError: if ``model.optimizer`` is None.
    """
    import deepspeed
    assert isinstance(model, deepspeed.DeepSpeedEngine)

    # before the first training step, the model has no optimizer
    if not hasattr(model, "optimizer"):
        return

    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    # Prefer the optimizer's `parameter_offload` object when it exists.
    optimizer_offload = optimizer.parameter_offload if hasattr(optimizer, "parameter_offload") else optimizer

    uses_renamed_api = version.parse(deepspeed.__version__) >= version.parse("0.16.4")
    if uses_renamed_api:
        # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
        optimizer_offload._register_deepspeed_module(optimizer_offload.module)
    else:
        optimizer_offload._register_hooks_recursively(optimizer_offload.module)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _remove_hooks(model: nn.Module) -> None:
    """
    Remove the DeepSpeed ZeRO-3 forward/backward hooks from ``model``.

    Also clears ``ds_active_sub_modules`` on every parameter, including the external
    parameters reported by ``ds_external_parameters``. No-op if the engine has no
    ``optimizer`` attribute yet.
    :raises RuntimeError: if ``model.optimizer`` is None.
    """
    import deepspeed
    assert isinstance(model, deepspeed.DeepSpeedEngine)

    # before the first training step, the model has no optimizer
    if not hasattr(model, "optimizer"):
        return

    optimizer = model.optimizer
    if optimizer is None:
        raise RuntimeError("The model optimizer is None, which is not yet supported.")
    # Prefer the optimizer's `parameter_offload` object when it exists.
    optimizer_offload = optimizer.parameter_offload if hasattr(optimizer, "parameter_offload") else optimizer

    for param in _iter_params(optimizer_offload.module, recurse=True):
        param.ds_active_sub_modules.clear()

    # Forward hooks first, then backward hooks.
    for handle in itertools.chain(optimizer_offload.forward_hooks, optimizer_offload.backward_hooks):
        handle.remove()

    optimizer_offload.forward_hooks = []
    optimizer_offload.backward_hooks = []
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _iter_params(module, recurse=False):
    """Return a list of ``module``'s parameters followed by its DeepSpeed external parameters."""
    return [named_param[1] for named_param in _get_all_parameters(module, recurse)]
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _get_all_parameters(sub_module, recurse=False):
|
|
219
|
+
return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
|