project-llm-trainer 0.12.3 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_trainer/loss.py ADDED
@@ -0,0 +1,268 @@
+ from typing import List, Optional
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class LMLoss(nn.Module):
+     """
+     Next-token cross-entropy loss for causal language models, with optional
+     up-weighting of critical tokens.
+     """
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         *,
+         critical_tokens: Optional[List[int]] = None,
+         critical_alpha: float = 1.0,
+         vocab_size: int = 0
+     ):
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.critical_tokens = critical_tokens
+         self.critical_alpha = critical_alpha
+
+         weights = None
+         if critical_tokens and vocab_size > 0:
+             weights = torch.ones(vocab_size)
+             # Up-weight the critical tokens relative to the rest of the vocabulary.
+             weights[critical_tokens] = critical_alpha
+         self.register_buffer('weights', weights)
+
+     def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         # logits shape (batch, seq_len, vocab_size)
+         # labels shape (batch, seq_len)
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         logits = shift_logits.reshape(-1, shift_logits.shape[-1])
+         targets = shift_labels.reshape(-1)
+
+         ce_loss = F.cross_entropy(
+             logits,
+             targets,
+             ignore_index=self.ignore_index,
+             weight=self.weights.to(logits.device, dtype=logits.dtype) if self.weights is not None else None
+         )
+
+         # Optional extra penalty term:
+         # if self.critical_tokens:
+         #     crit_mask = torch.isin(targets, torch.tensor(self.critical_tokens).to(targets.device))
+         #     crit_logits = logits[crit_mask]
+         #     crit_targets = targets[crit_mask]
+         #     extra_loss = F.cross_entropy(crit_logits, crit_targets, ignore_index=self.ignore_index)
+         #     return ce_loss + extra_loss * (self.critical_alpha - 1)  # amplify the penalty
+
+         return ce_loss
+
+
+ class KDLoss(nn.Module):
+     """
+     Language Model Knowledge Distillation Loss
+     https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py#L266
+     """
+     def __init__(self, ignore_index: int = -100):
+         super().__init__()
+         self.ignore_index = ignore_index
+
+     def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
+         inf_mask = torch.isinf(logits)
+
+         logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+         prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
+
+         x = torch.sum(prod_probs, dim=-1).view(-1)
+         mask = (labels != self.ignore_index).int()
+
+         distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
+
+         return distil_loss
+
+
+ class DPOLoss(nn.Module):
+     def __init__(
+         self,
+         beta: float,
+         label_smoothing: float = 0.0,
+         ipo: bool = False
+     ):
+         super().__init__()
+         self.beta = beta
+         self.label_smoothing = label_smoothing
+         self.ipo = ipo
+
+     def forward(
+         self,
+         policy_chosen_logps: torch.Tensor,
+         policy_reject_logps: torch.Tensor,
+         ref_chosen_logps: torch.Tensor,
+         ref_reject_logps: torch.Tensor
+     ) -> torch.Tensor:
+         pi_logratios = policy_chosen_logps - policy_reject_logps
+         ref_logratios = ref_chosen_logps - ref_reject_logps
+         logits = pi_logratios - ref_logratios
+
+         if self.ipo:
+             losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
+         else:
+             # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 recovers the original DPO loss (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
+             losses = (
+                 -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                 - F.logsigmoid(-self.beta * logits) * self.label_smoothing
+             )
+
+         loss = losses.mean()
+
+         # chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps).detach()
+         # rejected_rewards = self.beta * (policy_reject_logps - ref_reject_logps).detach()
+
+         return loss
+
+
+ class PPOLoss(nn.Module):
+     """
+     PPO (Proximal Policy Optimization) loss.
+     This class computes the actor and value losses together.
+     """
+
+     def __init__(
+         self,
+         clip_eps: float,
+         vf_coef: float,
+     ):
+         """
+         :param clip_eps: epsilon for the PPO clipping range.
+         :param vf_coef: coefficient for the value-function loss.
+         """
+         super().__init__()
+         self.clip_eps = clip_eps
+         self.vf_coef = vf_coef
+
+     def forward(
+         self,
+         log_probs: torch.Tensor,
+         old_log_probs: torch.Tensor,
+         values: torch.Tensor,
+         old_values: torch.Tensor,
+         returns: torch.Tensor,
+         advantages: torch.Tensor,
+         mask: torch.Tensor
+     ):
+         """
+         Compute the total PPO loss together with its actor and value components.
+
+         :param log_probs: log probabilities under the current policy, shape [batch_size, seq_len]
+         :param old_log_probs: log probabilities under the rollout policy, shape [batch_size, seq_len]
+         :param values: values predicted by the current critic, shape [batch_size, seq_len]
+         :param old_values: values recorded at rollout time, shape [batch_size, seq_len]
+         :param returns: returns computed via GAE, shape [batch_size, seq_len]
+         :param advantages: advantages computed via GAE, shape [batch_size, seq_len]
+         :param mask: mask selecting the generated tokens, shape [batch_size, seq_len]
+         :return: (total_loss, actor_loss, value_loss, approx_kl, clip_frac)
+         """
+         # Value loss with clipping
+         values_clipped = old_values + torch.clamp(values - old_values, -self.clip_eps, self.clip_eps)
+         vf_loss_unclipped = F.mse_loss(values, returns, reduction='none')
+         vf_loss_clipped = F.mse_loss(values_clipped, returns, reduction='none')
+         value_loss = torch.max(vf_loss_unclipped, vf_loss_clipped)
+         # Apply mask and average
+         value_loss = 0.5 * (value_loss * mask).sum() / mask.sum().clamp(min=1.0)
+         value_loss = value_loss * self.vf_coef
+
+         # Actor (policy) loss
+         # Probability ratio between new and old policies: r_t = exp(log_prob_new - log_prob_old)
+         ratio = torch.exp(log_probs - old_log_probs)  # [batch_size, seq_len]
+
+         # Clipped surrogate objective
+         surr1 = ratio * advantages  # [batch_size, seq_len]
+         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages  # [batch_size, seq_len]
+
+         # Take the elementwise minimum and negate it (we maximize the objective by
+         # minimizing its negative), restricted to the generated tokens via the mask.
+         actor_loss = -torch.sum(torch.min(surr1, surr2) * mask) / torch.sum(mask).clamp(min=1.0)
+
+         total_loss = actor_loss + value_loss
+
+         with torch.no_grad():
+             # Approximate KL divergence
+             logratios = log_probs - old_log_probs
+             approx_kl = torch.sum(((torch.exp(logratios) - 1) - logratios) * mask) / mask.sum().clamp(min=1.0)
+
+             # Fraction of tokens whose ratio was clipped
+             clipped = ratio.gt(1.0 + self.clip_eps) | ratio.lt(1.0 - self.clip_eps)
+             clip_frac = torch.sum(clipped.float() * mask) / mask.sum().clamp(min=1.0)
+
+         return total_loss, actor_loss, value_loss, approx_kl, clip_frac
+
+
+ class GRPOLoss(nn.Module):
+     def __init__(
+         self,
+         beta: float,
+         clip_eps_low: float,
+         clip_eps_high: Optional[float] = None,
+         delta: Optional[float] = None,
+         importance_sampling_level: str = 'token',
+         loss_type: str = 'grpo',
+         gen_max_new_tokens: Optional[int] = None
+     ):
+         super().__init__()
+
+         self.beta = beta
+         self.clip_eps_low = clip_eps_low
+         # Use `is not None` so an explicit clip_eps_high=0.0 is not silently replaced.
+         self.clip_eps_high = clip_eps_high if clip_eps_high is not None else clip_eps_low
+         self.delta = delta
+         self.importance_sampling_level = importance_sampling_level
+         self.loss_type = loss_type
+         self.gen_max_new_tokens = gen_max_new_tokens
+
+     def forward(
+         self,
+         log_probs: torch.Tensor,
+         old_log_probs: torch.Tensor,
+         ref_log_probs: torch.Tensor,
+         completion_mask: torch.Tensor,
+         advantages: torch.Tensor
+     ) -> torch.Tensor:
+
+         if self.beta != 0.0:
+             per_token_kl = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1
+         else:
+             per_token_kl = None
+
+         log_ratio = log_probs - old_log_probs
+         if self.importance_sampling_level == "seq":
+             # GSPO: sequence-level importance weights
+             log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+             log_importance_weights = log_importance_weights.unsqueeze(-1)
+         else:
+             # GRPO: token-level importance weights
+             log_importance_weights = log_ratio
+
+         coef_1 = torch.exp(log_importance_weights)
+         coef_2 = torch.clamp(coef_1, 1 - self.clip_eps_low, 1 + self.clip_eps_high)
+
+         # Two-sided clipping
+         if self.delta is not None:
+             coef_1 = torch.clamp(coef_1, max=self.delta)
+
+         per_token_loss1 = coef_1 * advantages
+         per_token_loss2 = coef_2 * advantages
+         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+         if self.beta != 0.0:
+             per_token_loss = per_token_loss + self.beta * per_token_kl
+
+         if self.loss_type == "bnpo":
+             loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+         elif self.loss_type == "dr_grpo":
+             assert self.gen_max_new_tokens is not None
+             loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.gen_max_new_tokens)
+         else:
+             loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+
+         return loss
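
Taken together, these classes cover pretraining/SFT (LMLoss), distillation (KDLoss), and preference or RL fine-tuning (DPOLoss, PPOLoss, GRPOLoss). Below is a minimal smoke-test sketch for two of them; the vocabulary size, tensor shapes, and critical-token ids are illustrative values, not taken from the package:

```python
import torch

from llm_trainer.loss import LMLoss, DPOLoss

# Next-token loss; tokens 5 and 9 are up-weighted 2x (illustrative ids).
lm_loss_fn = LMLoss(critical_tokens=[5, 9], critical_alpha=2.0, vocab_size=32)
logits = torch.randn(2, 8, 32)          # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 8))   # (batch, seq_len)
print(lm_loss_fn(logits, labels))

# DPO operates on per-sequence log probabilities (one scalar per sequence).
dpo_loss_fn = DPOLoss(beta=0.1)
chosen = torch.randn(4)
print(dpo_loss_fn(chosen, chosen - 0.5, torch.randn(4), torch.randn(4)))
```

PPOLoss and GRPOLoss follow the same masked, token-level pattern: they take per-token log probabilities of shape (batch, seq_len) plus a mask selecting the generated tokens.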
llm_trainer/parallel.py ADDED
@@ -0,0 +1,220 @@
+ import os
+ from typing import Optional, Tuple
+ from abc import ABC, abstractmethod
+
+ import torch
+ from torch import nn
+ import torch.distributed as dist
+ from torch.utils.data import Dataset, DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ try:
+     import deepspeed
+ except ImportError:
+     deepspeed = None
+
+ from .log import Logger
+
+
+ class Parallel(ABC):
+     def __init__(
+         self,
+         _init_process_group: bool = True,
+         _use_parallel: bool = True
+     ):
+         self._initialize(_init_process_group, _use_parallel)
+
+     def _initialize(
+         self,
+         _init_process_group: bool,
+         _use_parallel: bool
+     ):
+         self._global_rank: int = int(os.environ.get('RANK', -1))
+         self._local_rank: int = int(os.environ.get('LOCAL_RANK', -1))
+         self._use_parallel: bool = _use_parallel and self._global_rank != -1
+
+         self._sampler: Optional[DistributedSampler] = None
+         self.model: Optional[nn.Module] = None
+
+         try:
+             torch.set_float32_matmul_precision('high')
+             torch.backends.cuda.matmul.allow_tf32 = True
+             torch.backends.cudnn.allow_tf32 = True
+         except Exception:
+             pass
+
+         if self._use_parallel:
+             if _init_process_group:
+                 dist.init_process_group(backend='nccl')
+
+             self.device: str = f'cuda:{self._local_rank}'
+             self.device_type: str = 'cuda'
+             torch.cuda.set_device(self.device)
+
+             Logger.std_log(f'global_rank={self._global_rank}, local_rank={self._local_rank}, world_size={self.world_size}')
+         else:
+             device = "cpu"
+             if torch.cuda.is_available():
+                 device = "cuda"
+             elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+                 device = "mps"
+
+             self.device: str = device
+             self.device_type: str = device
+
+     @abstractmethod
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]: ...
+
+     def process_dataloader(
+         self,
+         dataset: Dataset,
+         data_loader_kwargs: dict,
+         sampler_kwargs: Optional[dict] = None
+     ) -> DataLoader:
+         """
+         :param dataset:
+         :param data_loader_kwargs: supported keys include
+             "batch_size" (int), "pin_memory" (bool), "collate_fn",
+             "num_workers" (int), "shuffle" (bool), "drop_last" (bool)
+         :param sampler_kwargs: supported keys include
+             "shuffle" (bool), "drop_last" (bool)
+         :return:
+         """
+
+         if self._use_parallel:
+             self._sampler = DistributedSampler(dataset=dataset, **(sampler_kwargs or {}))
+             # DataLoader forbids combining a sampler with shuffle=True;
+             # shuffling is delegated to the DistributedSampler here.
+             data_loader_kwargs = {k: v for k, v in data_loader_kwargs.items() if k != 'shuffle'}
+             return DataLoader(dataset=dataset, sampler=self._sampler, **data_loader_kwargs)
+
+         return DataLoader(dataset=dataset, **data_loader_kwargs)
+
+     def on_epoch_start(self, epoch):
+         if self._sampler:
+             self._sampler.set_epoch(epoch)
+
+     def on_epoch_end(self, epoch): ...
+
+     def synchronize(self):
+         if self._use_parallel:
+             torch.cuda.synchronize(device=self.device)
+
+     def destroy(self):
+         if self._use_parallel:
+             dist.destroy_process_group()
+
+     @property
+     def parallel_train(self) -> bool:
+         return self._use_parallel
+
+     @property
+     def is_main_process(self) -> bool:
+         if self._use_parallel:
+             return self._global_rank == 0
+
+         return True
+
+     @property
+     def world_size(self) -> int:
+         if self._use_parallel:
+             return dist.get_world_size()
+         return 1
+
+     def wait(self, msg=None):
+         if self.world_size == 1:
+             return
+
+         msg = f' for {msg}' if msg else ''
+         Logger.std_log(f'wait at {self.device}{msg}')
+         dist.barrier()
+         Logger.std_log(f'continue at {self.device}{msg}')
+
+
+ class DsParallel(Parallel):
+     def __init__(self):
+         deepspeed.init_distributed(dist_backend='nccl')
+         super().__init__(_init_process_group=False)
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         """
+         :param model:
+         :param optimizer:
+         :param kwargs: see the DeepSpeed configuration reference
+         :param save_instance:
+         :return:
+         """
+         model, optim, _, _ = deepspeed.initialize(
+             model=model,
+             optimizer=optimizer,
+             dist_init_required=False,
+             config_params=kwargs
+         )
+
+         if save_instance:
+             self.model = model
+
+         return model, optim
+
+     def synchronize(self): ...
+
+     def destroy(self): ...
+
+
+ class DdpParallel(Parallel):
+     def __init__(self):
+         super().__init__()
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         model.to(self.device)
+
+         if self._use_parallel:
+             # self.model = DDP(module=model, broadcast_buffers=False, find_unused_parameters=True)
+             model = DDP(module=model, device_ids=[self._local_rank], output_device=self._local_rank)
+
+         if save_instance:
+             self.model = model
+
+         return model, optimizer
+
+
+ class NoneParallel(Parallel):
+     def __init__(self):
+         super().__init__(_use_parallel=False)
+
+     def process(
+         self,
+         model: nn.Module,
+         optimizer: torch.optim.Optimizer,
+         kwargs: Optional[dict] = None,
+         save_instance: bool = True
+     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
+         model.to(self.device)
+
+         if save_instance:
+             self.model = model
+
+         return model, optimizer
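
The same Parallel interface drives single-process, DDP, and DeepSpeed training. Below is a minimal single-process wiring sketch; the toy model, dataset, and hyperparameters are illustrative, and under torchrun the same loop works with DdpParallel in place of NoneParallel:

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

from llm_trainer.parallel import NoneParallel

parallel = NoneParallel()  # picks cpu/cuda/mps automatically
model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = parallel.process(model, optimizer)

dataset = TensorDataset(torch.randn(32, 16), torch.randint(0, 4, (32,)))
loader = parallel.process_dataloader(dataset, {'batch_size': 8, 'shuffle': True})

for epoch in range(2):
    parallel.on_epoch_start(epoch)  # re-seeds the DistributedSampler when present
    for x, y in loader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(x.to(parallel.device)), y.to(parallel.device))
        loss.backward()
        optimizer.step()

parallel.destroy()
```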
@@ -0,0 +1,219 @@
+ from typing import Optional
+ from contextlib import contextmanager
+ import itertools
+ from packaging import version
+ from torch import nn
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist
+
+ from .tools import TrainerTools
+ from .parallel import DsParallel, DdpParallel
+
+
+ @contextmanager
+ def unwrap_model_for_generation(model: nn.Module):
+     """
+     Context manager to unwrap distributed or accelerated models for generation tasks.
+
+     Args:
+         model:
+             Model to be unwrapped.
+     Yields:
+         Unwrapped model.
+
+     Example:
+     ```python
+     with unwrap_model_for_generation(model) as unwrapped_model:
+         generated_outputs = unwrapped_model.generate(input_ids)
+     ```
+     """
+     if isinstance(TrainerTools().parallel, DsParallel):
+         import deepspeed
+         assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+         if model.zero_optimization_stage() == 3:
+             with deepspeed.zero.GatheredParameters(model.parameters()):
+                 _remove_hooks(model)
+                 yield unwrap_model(model)
+                 _add_hooks(model)
+         else:
+             yield unwrap_model(model)
+     elif isinstance(TrainerTools().parallel, DdpParallel):
+         yield unwrap_model(model)
+     else:
+         yield model
+
+
+ def sync_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+     """
+     Must be called on all ranks; on non-rank-0 ranks, `_to` may be None.
+     This function does not support `_to` being a ZeRO-3 (sharded) model.
+     """
+     if isinstance(TrainerTools().parallel, DsParallel):
+         state_dict = _get_ds_model_params(_from, only_rank0=_to is None)
+     elif isinstance(_from, DDP):
+         state_dict = _from.module.state_dict()
+     else:
+         state_dict = _from.state_dict()
+
+     if not _to or not state_dict:
+         return
+
+     unwrap_to_model = unwrap_model(_to)
+     if mixup_alpha == 1.0:
+         # Overwrite the target parameters directly.
+         unwrap_to_model.load_state_dict(state_dict, strict=False)
+     else:
+         # Blend source and target parameters (EMA-style update).
+         for param_name, target_param in unwrap_to_model.named_parameters():
+             if param_name in state_dict:
+                 from_param_tensor = state_dict[param_name]
+                 target_param.data.mul_(1.0 - mixup_alpha).add_(
+                     from_param_tensor.data.to(target_param.device),
+                     alpha=mixup_alpha
+                 )
+
+
+ def unwrap_model(model) -> nn.Module:
+     try:
+         import deepspeed
+         if isinstance(model, deepspeed.DeepSpeedEngine):
+             return model.module
+     except ImportError:
+         pass
+
+     if isinstance(model, DDP):
+         return model.module
+
+     return model
+
+
+ def _get_ds_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
+     """
+     Must be called on all ranks; only rank 0 receives the state dict.
+     """
+     import deepspeed
+     assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+     if model.zero_optimization_stage() != 3:
+         if TrainerTools().parallel.is_main_process:
+             return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
+         return None
+
+     # --- ZeRO-3 ---
+     # Call GatheredParameters only once, passing all parameters.
+     with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
+         if TrainerTools().parallel.is_main_process:
+             # Inside this 'with' block, model.module on rank 0 holds the full
+             # parameters, so state_dict() can be called as on an ordinary model.
+             full_state_dict = model.module.state_dict()
+
+             # Clone to CPU and return.
+             return {k: v.cpu().clone() for k, v in full_state_dict.items()}
+
+     # The other ranks fall through once the context exits and simply return None.
+     return None
+
+
+ def _get_ds_model_params(model: nn.Module, only_rank0=False):
+     """
+     Efficiently extract the full FP32 state_dict from a running DeepSpeedEngine.
+     Compatible with ZeRO stages 0, 1, 2, and 3, including correct handling of
+     the sharded parameters under ZeRO-3.
+     """
+
+     import deepspeed
+     assert isinstance(model, deepspeed.DeepSpeedEngine)
+     state_dict = _get_ds_full_state_dict_on_rank0(model)
+
+     # At this point only rank 0 holds a valid dict; the other ranks hold None.
+     # Broadcast it to all processes if requested.
+     if not only_rank0 and TrainerTools().parallel.world_size > 1:
+         # Prepare a list: rank 0 carries the data, the other ranks a placeholder.
+         object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
+         # The broadcast is blocking and synchronizes all processes.
+         dist.broadcast_object_list(object_list, src=0)
+         # Every process reads its copy of the broadcast state_dict from the list.
+         state_dict = object_list[0]
+
+     return state_dict
+
+
+ def _copy_params(model, target_model, mixup_alpha):
+     for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
+         target_param.data.mul_(1.0 - mixup_alpha).add_(copy_param.data, alpha=mixup_alpha)
+
+
+ def _sync_ds_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+     import deepspeed
+     assert isinstance(_from, deepspeed.DeepSpeedEngine)
+
+     origin_from = unwrap_model(_from)
+
+     if _from.zero_optimization_stage() == 3:
+         with deepspeed.zero.GatheredParameters(list(origin_from.parameters()) + list(_to.parameters()), modifier_rank=0):
+             # why only rank 0?
+             if TrainerTools().parallel.is_main_process:
+                 _copy_params(origin_from, _to, mixup_alpha)
+     else:
+         _copy_params(origin_from, _to, mixup_alpha)
+
+
+ def _sync_ddp_model_params(_from: nn.Module, _to: Optional[nn.Module], mixup_alpha: float = 1.0):
+     assert isinstance(_from, DDP)
+
+     origin_from = unwrap_model(_from)
+     _copy_params(origin_from, _to, mixup_alpha)
+
+
+ def _add_hooks(model: nn.Module) -> None:
+     """Adds the optimizer hooks back to a DeepSpeed ZeRO-3 model."""
+     import deepspeed
+     assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+         return
+     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+         optimizer_offload = model.optimizer.parameter_offload
+     elif model.optimizer is not None:
+         optimizer_offload = model.optimizer
+     else:
+         raise RuntimeError("The model optimizer is None, which is not yet supported.")
+     if version.parse(deepspeed.__version__) >= version.parse("0.16.4"):
+         # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
+         optimizer_offload._register_deepspeed_module(optimizer_offload.module)
+     else:
+         optimizer_offload._register_hooks_recursively(optimizer_offload.module)
+
+
+ def _remove_hooks(model: nn.Module) -> None:
+     """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+     import deepspeed
+     assert isinstance(model, deepspeed.DeepSpeedEngine)
+
+     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
+         return
+     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
+         optimizer_offload = model.optimizer.parameter_offload
+     elif model.optimizer is not None:
+         optimizer_offload = model.optimizer
+     else:
+         raise RuntimeError("The model optimizer is None, which is not yet supported.")
+
+     for param in _iter_params(optimizer_offload.module, recurse=True):
+         param.ds_active_sub_modules.clear()
+
+     for hook in optimizer_offload.forward_hooks:
+         hook.remove()
+     for hook in optimizer_offload.backward_hooks:
+         hook.remove()
+
+     optimizer_offload.forward_hooks = []
+     optimizer_offload.backward_hooks = []
+
+
+ def _iter_params(module, recurse=False):
+     return [param for _, param in _get_all_parameters(module, recurse)]
+
+
+ def _get_all_parameters(sub_module, recurse=False):
+     return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
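
For reference, here is a self-contained sketch of the EMA-style blend that _copy_params (and the mixup_alpha path of sync_model_params) performs, written in plain PyTorch without the distributed wrappers; the blend_params helper and the toy Linear modules are illustrative, not part of the package:

```python
import torch
from torch import nn

def blend_params(source: nn.Module, target: nn.Module, mixup_alpha: float) -> None:
    # target <- (1 - alpha) * target + alpha * source; alpha = 1.0 overwrites target.
    with torch.no_grad():
        for tgt, src in zip(target.parameters(), source.parameters()):
            tgt.mul_(1.0 - mixup_alpha).add_(src, alpha=mixup_alpha)

source, target = nn.Linear(4, 4), nn.Linear(4, 4)
blend_params(source, target, mixup_alpha=0.5)  # e.g. refreshing a reference model
```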