project-llm-trainer 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of project-llm-trainer might be problematic.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +707 -0
- llm_trainer/checkpoint.py +114 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +311 -0
- llm_trainer/ds_checkpoint.py +72 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +463 -0
- llm_trainer/grpo_trainer.py +410 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +266 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +686 -0
- llm_trainer/scheduler.py +220 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +327 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +630 -0
- project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.13.4.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.13.4.data/scripts/ds_train +17 -0
- project_llm_trainer-0.13.4.data/scripts/py_train +12 -0
- project_llm_trainer-0.13.4.data/scripts/smart_train +37 -0
- project_llm_trainer-0.13.4.data/scripts/vis_log +98 -0
- project_llm_trainer-0.13.4.data/scripts/vis_lr +46 -0
- project_llm_trainer-0.13.4.dist-info/METADATA +9 -0
- project_llm_trainer-0.13.4.dist-info/RECORD +32 -0
- project_llm_trainer-0.13.4.dist-info/WHEEL +5 -0
- project_llm_trainer-0.13.4.dist-info/top_level.txt +1 -0
llm_trainer/utils.py
ADDED
@@ -0,0 +1,630 @@
import random
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from .tools import TrainerTools
import numpy as np
from typing import Union, List, Optional, Tuple


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def autocast(device_type):
    if TrainerTools().use_amp:
        dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
        return torch.autocast(
            device_type=device_type,
            dtype=dtype,
            enabled=True,
            cache_enabled=None
        )
    else:
        return nullcontext()

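A minimal usage sketch for the helper above (illustrative: it assumes `autocast` is importable from `llm_trainer.utils` and that `TrainerTools` has already been configured, since `use_amp` decides whether the returned context is a real `torch.autocast` or a no-op `nullcontext`):

import torch
import torch.nn as nn
from llm_trainer.utils import autocast  # assumed import path

model = nn.Linear(16, 4)   # stand-in model
x = torch.randn(2, 16)

# The call site is identical with or without AMP: when use_amp is False the
# helper returns nullcontext() and the forward pass runs in full precision.
with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
    y = model(x)
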
def create_doc_boundary_mask(
    input_ids: torch.Tensor,
    dtype: torch.dtype
) -> torch.Tensor:
    """
    Build an attention mask that blocks cross-document attention, based on the
    positions of the end-of-text (eot) token.

    The mask prevents a token from attending to tokens that belong to earlier
    documents. For example, given the input `[[1, 2, eot, 3, 4, eot]]`,
    tokens `3` and `4` cannot attend to `1`, `2`, or the first `eot`.

    Args:
        input_ids (torch.Tensor): Input token IDs of shape (bsz, seq_len).
        dtype (torch.dtype): Target dtype of the mask.

    Returns:
        torch.Tensor: A mask tensor in the form the attention layer expects,
            with shape (bsz, 1, seq_len, seq_len).
            Positions holding -inf are masked out; positions holding 0 allow attention.
    """
    # Get batch size and sequence length
    bsz, seq_len = input_ids.shape

    # 1. Locate every eot token.
    # is_eot is a boolean tensor of shape (bsz, seq_len)
    is_eot = (input_ids == TrainerTools().tokenizer.end)

    # 2. Assign a document ID to every token.
    # cumsum produces monotonically increasing document IDs: the document a
    # token belongs to depends on how many eot tokens precede it.
    # Example:
    #   input_ids:      [[1, 2, 3, eot, 4, 5, eot]]
    #   is_eot:         [F, F, F, T, F, F, T] -> [0, 0, 0, 1, 0, 0, 1]
    #   doc_ids_ending: [0, 0, 0, 1, 1, 1, 2]  (result of cumsum)
    #   doc_ids:        [0, 0, 0, 0, 1, 1, 1]  (after shifting right by one)
    # This correctly assigns document 0 to the first four tokens and
    # document 1 to the last three.
    doc_ids_ending = torch.cumsum(is_eot, dim=-1)
    doc_ids = F.pad(doc_ids_ending[:, :-1], (1, 0), value=0)

    # 3. Build the mask by comparing query and key document IDs.
    # Goal: mask whenever the query token's document ID is greater than the
    # key token's document ID.
    # query_doc_ids shape: (bsz, seq_len, 1)
    # key_doc_ids shape:   (bsz, 1, seq_len)
    query_doc_ids = doc_ids.unsqueeze(2)
    key_doc_ids = doc_ids.unsqueeze(1)

    # Broadcasting `query_doc_ids > key_doc_ids` yields a boolean tensor of
    # shape (bsz, seq_len, seq_len); True marks exactly the positions to mask.
    boundary_mask = query_doc_ids > key_doc_ids

    # 4. Convert the boolean mask into the float mask (-inf / 0) attention expects
    final_mask = torch.zeros(
        (bsz, seq_len, seq_len), device=input_ids.device, dtype=dtype
    )
    final_mask.masked_fill_(boundary_mask, torch.finfo(dtype).min)

    # 5. Add a head dimension to match (bsz, num_heads, seq_len, seq_len);
    # a single mask is produced here and broadcast across all heads.
    return final_mask.unsqueeze(1)

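A small standalone check of the document-ID trick used above (the eot id of 0 is hypothetical; the real function reads it from `TrainerTools().tokenizer.end`):

import torch
import torch.nn.functional as F

eot = 0                                            # hypothetical end-of-text id
input_ids = torch.tensor([[1, 2, eot, 3, 4, eot]])
is_eot = (input_ids == eot)
doc_ids = F.pad(torch.cumsum(is_eot, dim=-1)[:, :-1], (1, 0), value=0)
# doc_ids -> [[0, 0, 0, 1, 1, 1]]: tokens 3 and 4 belong to document 1
blocked = doc_ids.unsqueeze(2) > doc_ids.unsqueeze(1)
print(blocked[0, 3, 0])   # True: token 3 may not attend back to token 1
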
def generate_position_ids(input_ids: torch.Tensor):
    """
    Generate the position_ids tensor for packed sequences. Positions restart
    at 0 after each end-of-text token (taken from TrainerTools().tokenizer.end).

    Args:
        input_ids (torch.Tensor): Input token IDs of shape (batch_size, sequence_length).

    Returns:
        torch.Tensor: The generated position_ids tensor.
    """
    # Shape of the input tensor
    batch_size, seq_length = input_ids.shape

    # Start from an all-zero tensor of the same shape; the first token's
    # position is always 0, so this initialization is correct.
    position_ids = torch.zeros_like(input_ids, dtype=torch.long)

    # Walk the sequence starting from the second time step (t=1)
    for t in range(1, seq_length):
        # Check whether the token at the previous step (t-1) is the EOT token;
        # this yields one boolean per sequence in the batch.
        is_reset_token = (input_ids[:, t - 1] == TrainerTools().tokenizer.end)

        # Position IDs at the previous step
        prev_position_ids = position_ids[:, t - 1]

        # If the previous token is EOT, reset the current position to 0;
        # otherwise increment the previous position by 1.
        # torch.where selects per sequence based on is_reset_token.
        position_ids[:, t] = torch.where(is_reset_token, 0, prev_position_ids + 1)

    return position_ids

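For the packed input `[[a, b, eot, c, d]]` the loop above yields `[[0, 1, 2, 0, 1]]`: the position after the eot restarts at 0. A vectorized equivalent without the Python loop could look like the sketch below (hypothetical helper taking an explicit `eot_id` argument, same semantics):

import torch

def packed_position_ids(input_ids: torch.Tensor, eot_id: int) -> torch.Tensor:
    bsz, seq_len = input_ids.shape
    idx = torch.arange(seq_len, device=input_ids.device).expand(bsz, -1)
    # a position counter restarts right after each eot token
    reset = torch.zeros_like(input_ids, dtype=torch.bool)
    reset[:, 1:] = input_ids[:, :-1] == eot_id
    # index of the most recent restart at or before each position
    last_reset = torch.cummax(torch.where(reset, idx, torch.zeros_like(idx)), dim=1).values
    return idx - last_reset

print(packed_position_ids(torch.tensor([[5, 6, 0, 7, 8]]), eot_id=0))
# tensor([[0, 1, 2, 0, 1]])
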
def calc_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Compute position_ids from the attention_mask; mainly for left-padding.
    mask: [0, 0, 1, 1, 1] -> position_ids: [0, 0, 0, 1, 2]
    """
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)
    return position_ids

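A quick check of the left-padding arithmetic:

import torch

mask = torch.tensor([[0, 0, 1, 1, 1]])
pos = mask.long().cumsum(-1) - 1   # [[-1, -1, 0, 1, 2]]
pos.masked_fill_(mask == 0, 0)     # [[ 0,  0, 0, 1, 2]]
print(pos)
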
def repeat_image_tok(
    tokens: torch.Tensor,
    tokens_per_image: int,
    attention_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # tokens_per_image=3 -> <image>...xxxx -> <image><image><image>...xxx
    image_tok = TrainerTools().tokenizer.image
    mask = (tokens == image_tok)
    if not mask.any():
        return tokens, attention_mask

    # Repeat count per position: 1 by default, tokens_per_image at image tokens
    repeats = torch.ones_like(tokens, dtype=torch.long)
    repeats[mask] = tokens_per_image

    # Expand efficiently with repeat_interleave
    new_tokens = torch.repeat_interleave(tokens, repeats, dim=0)

    if attention_mask is not None:
        # Apply the same expansion to the mask
        new_mask = torch.repeat_interleave(attention_mask, repeats, dim=0)
        return new_tokens, new_mask

    return new_tokens, None

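The expansion hinges on `torch.repeat_interleave` with a per-position repeat count; a standalone sketch with a hypothetical image-token id of 9:

import torch

image_tok = 9                      # hypothetical <image> id
tokens = torch.tensor([9, 1, 2])   # a single 1-D sequence, as repeat_image_tok expects
repeats = torch.ones_like(tokens)
repeats[tokens == image_tok] = 3   # tokens_per_image = 3
print(torch.repeat_interleave(tokens, repeats, dim=0))
# tensor([9, 9, 9, 1, 2])
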
def batch_repeat_image_tok(
    tokens: torch.Tensor,
    tokens_per_image: int,
    attention_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    new_tokens_list = []
    new_masks_list = []
    has_mask = attention_mask is not None

    for i in range(len(tokens)):
        token_seq = tokens[i]
        mask_seq = attention_mask[i] if has_mask else None

        if has_mask:
            new_tok, new_mask = repeat_image_tok(token_seq, tokens_per_image, mask_seq)
            new_tokens_list.append(new_tok)
            new_masks_list.append(new_mask)
        else:
            new_tok, _ = repeat_image_tok(token_seq, tokens_per_image)
            new_tokens_list.append(new_tok)

    # Note: torch.stack requires all expanded rows to share one length, i.e.
    # every sequence in the batch must contain the same number of image tokens.
    ret_tokens = torch.stack(new_tokens_list, dim=0)
    if has_mask:
        ret_masks = torch.stack(new_masks_list, dim=0)
        return ret_tokens, ret_masks

    return ret_tokens, None

def pretrain_collate_fn(batch_data):
    # [[x,x,x], [y,y,y]]
    inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
    # cross-entropy's default ignore_index is -100
    labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)

    # inputs, labels
    return {
        'inputs': inputs,
        'labels': labels
    }

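Why labels are padded with -100: `F.cross_entropy` skips those positions through its default `ignore_index`. A tiny check:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 10)         # (batch, seq, vocab)
labels = torch.tensor([[4, 7, -100]])  # last position is padding
# cross_entropy expects (N, C, d); ignore_index defaults to -100,
# so only the first two tokens contribute to the loss
loss = F.cross_entropy(logits.transpose(1, 2), labels)
print(loss)
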
def get_sft_collate_fn(mask_prompt: bool):
    def sft_collate_fn(batch_data):
        """
        For SFT, no loss is computed on the prompt part. For example:
        logits: [USER]hello[BOT]hi[SEP]
        labels: [USER]hello[BOT]hi[SEP]

        shift_logits: [USER]hello[BOT]hi
        shift_labels: hello[BOT]hi[SEP]

        mask_labels: mask mask mask mask hi[SEP]
        * mask=-100 behaves like pad


        Multi-turn dialogue:
        [USER]hello[BOT]hi[SEP][USER]good[BOT]bad[SEP]
        mask: mask mask mask mask hi[SEP] mask mask mask mask bad[SEP]
        """
        batch_train_data = []
        image_tags = []
        for item in batch_data:
            batch_train_data.append(item['inputs'])
            image_tags.append(item['image_tag'])

        # [[x,x,x], [y,y,y]]
        inputs = pad_sequence(batch_train_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
        # cross-entropy's default ignore_index is -100
        labels = pad_sequence(batch_train_data, batch_first=True, padding_value=-100)

        if mask_prompt:
            labels = _mask_prompt(labels)

        return {
            'inputs': inputs,
            'labels': labels,
            'image_tags': image_tags
        }

    return sft_collate_fn

def get_dpo_collate_fn(mask_prompt: bool):
    def dpo_collate_fn(batch_data):
        # batch_data: [{'chosen': chosen, 'rejected': rejected}, {'chosen': chosen, 'rejected': rejected}]
        chosen_inputs = []
        chosen_labels = []
        rejected_inputs = []
        rejected_labels = []

        max_len = 0
        for key in ['chosen', 'rejected']:
            max_len = max(max(len(item[key]) for item in batch_data), max_len)

        for item in batch_data:
            chosen_sequence = item['chosen']
            chosen_inputs.append(chosen_sequence + [TrainerTools().tokenizer.pad] * (max_len - len(chosen_sequence)))
            chosen_labels.append(chosen_sequence + [-100] * (max_len - len(chosen_sequence)))

            rejected_sequence = item['rejected']
            rejected_inputs.append(rejected_sequence + [TrainerTools().tokenizer.pad] * (max_len - len(rejected_sequence)))
            rejected_labels.append(rejected_sequence + [-100] * (max_len - len(rejected_sequence)))

        chosen_inputs = torch.tensor(chosen_inputs).long()
        chosen_labels = torch.tensor(chosen_labels).long()
        if mask_prompt:
            chosen_labels = _mask_prompt(chosen_labels)

        rejected_inputs = torch.tensor(rejected_inputs).long()
        rejected_labels = torch.tensor(rejected_labels).long()
        if mask_prompt:
            rejected_labels = _mask_prompt(rejected_labels)

        return {
            'chosen_inputs': chosen_inputs,
            'chosen_labels': chosen_labels,
            'rejected_inputs': rejected_inputs,
            'rejected_labels': rejected_labels
        }

    return dpo_collate_fn

def split_batch(data_per_batch: dict) -> list[dict]:
    """
    from: data_per_batch("sequence_ids": [group_size, max_generate_len] ...)
    to:   [dict("sequence_ids": [max_generate_len] ...) ... group_size]
    """

    group_size = data_per_batch['sequence_ids'].size(0)
    # [{"sequence_ids": xxx, "old_log_probs": xxx...}, ...]
    group_data = [{} for _ in range(group_size)]

    keys = (
        'sequence_ids',
        'old_log_probs',
        'ref_log_probs',
        'advantages',
        'attention_mask',
        'mask',
    )

    for key in keys:
        value = data_per_batch[key]
        if value is None:
            vals = [None] * group_size
        else:
            vals = torch.unbind(value)

        for i, v in enumerate(vals):
            group_data[i][key] = v

    return group_data

def join_batch(batch_data: list[dict]) -> dict:
    """
    from: [dict("sequence_ids": [max_generate_len] ...), ...]
    to:   dict("sequence_ids": [batch_size, max_generate_len], ...)
    """

    result = {}
    keys = (
        'sequence_ids',
        'old_log_probs',
        'ref_log_probs',
        'advantages',
        'attention_mask',
        'mask',
    )

    for key in keys:
        # [sequence_ids, sequence_ids ...]
        # shape [batch_size, seq_len]
        vals = [item[key] for item in batch_data]
        if all(v is not None for v in vals):
            data = _zero_pad_sequences(vals, "left")
        else:
            data = None
        result[key] = data

    return result

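`split_batch` and `join_batch` are near-inverses: the first unbinds each (group_size, seq_len) tensor into per-sample entries, and the second pads them back to a common length on the left and restacks. A sketch of the unbind step:

import torch

group = torch.arange(6).reshape(2, 3)   # (group_size=2, seq_len=3)
rows = torch.unbind(group)              # what split_batch does per key
print([r.tolist() for r in rows])       # [[0, 1, 2], [3, 4, 5]]
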
# Use torch's pad_sequence by default.
# If pad_sequence does not support the padding_side argument, set this flag
# to False and fall back to the flip-based approach.
_use_origin_pad_sequence = True
def left_pad_sequence(
    sequences: Union[torch.Tensor, List[torch.Tensor]],
    padding_value: float,
) -> torch.Tensor:
    global _use_origin_pad_sequence

    if _use_origin_pad_sequence:
        try:
            return pad_sequence(sequences, batch_first=True, padding_value=padding_value, padding_side='left')
        except TypeError:
            _use_origin_pad_sequence = False
            return left_pad_sequence(sequences, padding_value)
    else:
        # Reverse each sequence (e.g. [1,2,3] -> [3,2,1])
        reversed_sequences = [seq.flip(dims=(0,)) for seq in sequences]
        # Pad on the right, as usual
        padded_reversed = pad_sequence(reversed_sequences, batch_first=True, padding_value=padding_value)
        # Flip again to restore the original order (padding ends up on the left)
        return padded_reversed.flip(dims=(1,))

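The fallback path left-pads on older PyTorch builds by flipping twice; a quick check:

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4])]
rev = [s.flip(dims=(0,)) for s in seqs]            # reverse each sequence
out = pad_sequence(rev, batch_first=True, padding_value=0).flip(dims=(1,))
print(out)   # tensor([[1, 2, 3], [0, 0, 4]])
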
_use_memory_efficient_log_softmax = True
def log_softmax(logits, index) -> torch.Tensor:
    if _use_memory_efficient_log_softmax:
        return _selective_log_softmax(logits, index)

    # Convert raw logits into log probabilities along the vocabulary axis.
    # [batch_size, seq_len, vocab_size]
    log_probs = F.log_softmax(logits, dim=-1)

    # Reshape input_ids from (batch_size, seq_len) to (batch_size, seq_len, 1) for gathering.
    # Then, gather the log probability for each token in input_ids.
    selected_log_probs = log_probs.gather(dim=-1, index=index.unsqueeze(-1))

    # Remove the extra last dimension to get back to shape (batch_size, seq_len).
    return selected_log_probs.squeeze(-1)

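Both paths compute log p(token) per position; a sanity check that gathering after log_softmax equals the gather-minus-logsumexp form used by `_selective_log_softmax`:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10)
index = torch.randint(0, 10, (2, 4))
a = F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
b = torch.gather(logits, -1, index.unsqueeze(-1)).squeeze(-1) - torch.logsumexp(logits, dim=-1)
print(torch.allclose(a, b, atol=1e-6))   # True
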
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten values, using the mask to select which entries count."""
    mean, var = _masked_mean(values, mask), _masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened

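Whitening under a mask normalizes using statistics from the unmasked entries only; a quick check (assuming the function is importable from `llm_trainer.utils`):

import torch
from llm_trainer.utils import masked_whiten  # assumed import path

values = torch.tensor([1.0, 2.0, 3.0, 100.0])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])    # the outlier is excluded from the stats
out = masked_whiten(values, mask)
print(out[:3].mean())   # ~0: mean/var come from the three unmasked entries
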
def truncate_sequences_at_eos(
    sequences: torch.Tensor,
    eos_token_id: int,
    pad_token_id: int
) -> torch.Tensor:
    """
    Efficiently truncate each sequence in a batch at its first EOS token.
    Everything after the first EOS (excluding the EOS itself) is replaced
    with pad_token_id.

    This is a vectorized implementation for good GPU performance.
    It uses torch.where, so it does not rely on pad_token_id being 0.

    Args:
        sequences (torch.Tensor): Batched sequences of shape (batch_size, seq_len).
        eos_token_id (int): ID of the end-of-sequence token.
        pad_token_id (int): ID of the padding token.

    Returns:
        torch.Tensor: The truncated sequences, same shape as the input.
    """
    # Boolean mask marking every EOS token position
    eos_mask = (sequences == eos_token_id)

    # Index of the first True (first EOS token) in each row.
    # .int() keeps older torch versions happy; argmax needs a non-bool dtype.
    first_eos_indices = torch.argmax(eos_mask.int(), dim=1)

    # Check which sequences actually contain an EOS token.
    # For rows without any EOS, argmax returns 0, which would be ambiguous.
    has_eos = eos_mask.any(dim=1)

    # For sequences without an EOS, set the cut index to the sequence length
    # so nothing gets truncated by mistake.
    first_eos_indices[~has_eos] = sequences.shape[1]

    # Index tensor [0, 1, 2, ..., seq_len-1]
    indices_mask = torch.arange(sequences.shape[1], device=sequences.device)

    # Broadcast into a mask of tokens to keep: for each sequence, True where
    # token_index <= first_eos_index, so the EOS itself is kept, as the
    # docstring states.
    keep_mask = indices_mask <= first_eos_indices.unsqueeze(1)

    # Safe replacement via torch.where:
    # keep the original token where keep_mask is True, else pad_token_id
    truncated_sequences = torch.where(
        keep_mask,
        sequences,
        pad_token_id
    )

    return truncated_sequences

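A quick check mirroring the function body (eos=2, pad=0): everything after the first eos is padded while the eos itself is kept, and rows without an eos pass through untouched:

import torch

seqs = torch.tensor([[5, 2, 7, 8],
                     [5, 6, 7, 8]])   # second row has no eos
eos_mask = (seqs == 2)
first = torch.argmax(eos_mask.int(), dim=1)
first[~eos_mask.any(dim=1)] = seqs.shape[1]
keep = torch.arange(seqs.shape[1]) <= first.unsqueeze(1)
print(torch.where(keep, seqs, 0))
# tensor([[5, 2, 0, 0],
#         [5, 6, 7, 8]])
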
def disable_dropout_in_model(model: torch.nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

def _masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Compute the mean of a tensor over its unmasked values."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()

def _masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute the variance of a tensor over its unmasked values."""
    mean = _masked_mean(values, mask)
    centered_values = values - mean
    variance = _masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            return torch.tensor(0.0, device=values.device, dtype=values.dtype)

        # Note: if mask_sum == 1, there is a division-by-zero issue here;
        # avoid it by using a larger minibatch_size.
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance

def _selective_log_softmax(logits, index) -> torch.Tensor:
    if logits.dtype in [torch.float32, torch.float64]:
        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        # loop over rows to keep peak memory low: logsumexp one row at a time
        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
        per_token_logps = selected_logits - logsumexp_values
    else:
        # For lower-precision dtypes, compute log_softmax row by row to limit
        # peak memory while keeping numerical stability.
        per_token_logps = []
        for row_logits, row_labels in zip(logits, index):
            row_logps = F.log_softmax(row_logits, dim=-1)
            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
            per_token_logps.append(row_per_token_logps)
        per_token_logps = torch.stack(per_token_logps)
    return per_token_logps

def _mask_prompt(labels):
    """
    Mask out the prompt part and the fixed template tokens, keeping only the
    content the model actually has to generate.
    Strategy:
    1. From <system>/<user> up to </s>: mask everything.
    2. An <assistant> right after </s>: mask.
    3. An <answer> right after <assistant>: mask.
    4. A <think> right after <assistant>: keep.
    Examples:
    1. original: <system>system</s><user>user</s><assistant>content</s>
       masked:   mask mask mask mask mask mask mask content</s>
    2. original: <system>system</s><user>user</s><assistant><answer>content</answer></s>
       masked:   mask mask mask mask mask mask mask mask content</answer></s>
    3. original: <system>system</s><user>user</s><assistant><think>think</think><answer>content</answer></s>
       masked:   mask mask mask mask mask mask mask <think>think</think><answer>content</answer></s>
    """
    system_token_id = TrainerTools().tokenizer.system
    user_token_id = TrainerTools().tokenizer.user
    end_token_id = TrainerTools().tokenizer.end
    assistant_token_id = TrainerTools().tokenizer.assistant
    answer_start_token_id = TrainerTools().tokenizer.answer_start
    think_start_token_id = TrainerTools().tokenizer.think_start
    ignore_index = -100

    for batch, label in enumerate(labels):
        start_index = -1
        seq_len = len(label)

        for index, token in enumerate(label):
            if token == system_token_id or token == user_token_id:
                if start_index != -1:
                    labels[batch, start_index: index] = ignore_index

                start_index = index

            elif token == end_token_id and start_index != -1:
                end_mask_index = index
                next_idx = index + 1

                if next_idx < seq_len and label[next_idx] == assistant_token_id:
                    end_mask_index = next_idx

                    after_assistant_idx = next_idx + 1
                    if after_assistant_idx < seq_len:
                        token_after = label[after_assistant_idx]

                        if token_after == answer_start_token_id:
                            end_mask_index = after_assistant_idx
                        elif token_after == think_start_token_id:
                            # keep the <think> token unmasked
                            pass

                labels[batch, start_index: end_mask_index + 1] = ignore_index
                start_index = -1

        # If start_index is still not -1 after the loop, the sequence was
        # truncated: the whole tail belongs to the prompt (e.g. the user turn
        # was cut off mid-way), so it must all be masked.
        if start_index != -1:
            labels[batch, start_index:] = ignore_index

    return labels

def _zero_pad_sequences(
    sequences: list[torch.Tensor], side: str = "left"
) -> torch.Tensor:
    assert side in ("left", "right")
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)

class RunningMeanStd(nn.Module):
    def __init__(self, shape: Tuple[int, ...] = (), epsilon: float = 1e-5):
        super().__init__()
        self.shape = shape
        self.epsilon = epsilon

        # Buffers, so the running statistics move with the module
        # and are saved in checkpoints.
        self.register_buffer("mean", torch.zeros(shape, dtype=torch.float64))
        self.register_buffer("var", torch.ones(shape, dtype=torch.float64))
        self.register_buffer("count", torch.tensor(1e-4, dtype=torch.float64))

    def update(self, x: torch.Tensor):
        x = x.to(dtype=torch.float64)

        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        batch_count = torch.tensor(x.shape[0], device=x.device, dtype=torch.float64)

        if TrainerTools().parallel.parallel_train:
            # Recompute the batch statistics globally across all ranks
            # from all-reduced sums.
            dist.all_reduce(batch_count, op=dist.ReduceOp.SUM)

            batch_sum = x.sum(dim=0)
            dist.all_reduce(batch_sum, op=dist.ReduceOp.SUM)

            batch_sum_sq = (x ** 2).sum(dim=0)
            dist.all_reduce(batch_sum_sq, op=dist.ReduceOp.SUM)

            batch_mean = batch_sum / batch_count
            batch_mean_sq = batch_sum_sq / batch_count
            batch_var = batch_mean_sq - batch_mean ** 2

            # Guard against tiny negative values from floating-point error
            batch_var = torch.clamp(batch_var, min=0.0)

        # Merge the running statistics with the batch statistics
        # (parallel variance update).
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count

        new_mean = self.mean + delta * (batch_count / tot_count)

        m_a = self.var * self.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + (delta ** 2) * (self.count * batch_count / tot_count)

        new_var = M2 / tot_count

        self.mean = new_mean
        self.var = new_var
        self.count = tot_count

    def forward(self, x: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
        target_dtype = x.dtype
        mean = self.mean.to(target_dtype)
        var = self.var.to(target_dtype)

        if shift_mean:
            return (x - mean) / torch.sqrt(var + self.epsilon)
        else:
            return x / torch.sqrt(var + self.epsilon)

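A single-process usage sketch (assuming the class is importable as shown and `TrainerTools().parallel.parallel_train` is False, so the all-reduce branch is skipped):

import torch
from llm_trainer.utils import RunningMeanStd  # assumed import path

rms = RunningMeanStd(shape=())
for _ in range(10):
    rms.update(torch.randn(32) * 3 + 5)   # stream batches with mean ~5, std ~3
print(rms.mean.item(), rms.var.sqrt().item())   # roughly 5 and 3
print(rms(torch.tensor([5.0])))                 # ~0 after shift-and-scale
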
project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size
ADDED
@@ -0,0 +1,15 @@
#!python

if __name__ == '__main__':
    import sys
    arguments = sys.argv[1:]
    hidden_size = int(arguments[0])
    if len(arguments) > 1:
        multiple_of = int(arguments[1])
    else:
        multiple_of = 64

    intermediate_size = 4 * hidden_size
    intermediate_size = int(2 * intermediate_size / 3)
    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
    print(f'intermediate_size={intermediate_size}')

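Worked example, tracing the arithmetic above: `calc_intermediate_size 768` computes 4*768 = 3072, then int(2*3072/3) = 2048, then rounds up to the nearest multiple of 64 (still 2048), so it prints `intermediate_size=2048`. This matches the sizing rule used for LLaMA-style SwiGLU feed-forward layers.
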
project_llm_trainer-0.13.4.data/scripts/ddp_train
ADDED
@@ -0,0 +1,21 @@
#!python

if __name__ == '__main__':
    import os, sys
    arguments = sys.argv[1:]
    # file_name
    run_file_name = arguments[0]

    extra_args = ''
    if len(arguments) > 1:
        extra_args = f"{' '.join(arguments[1:])} "

    os.environ['PARALLEL_TYPE'] = 'ddp'

    if len(extra_args) == 0:
        extra_args = '--standalone --nproc_per_node=gpu '

    command = f'torchrun {extra_args}{run_file_name}'

    print(f'run command {command}')
    os.system(command)

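Usage sketch (the script name is illustrative): `ddp_train train.py` sets `PARALLEL_TYPE=ddp` and runs `torchrun --standalone --nproc_per_node=gpu train.py`. Note that any extra arguments replace the default torchrun flags rather than being appended to them.
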
project_llm_trainer-0.13.4.data/scripts/ds_train
ADDED
@@ -0,0 +1,17 @@
#!python

if __name__ == '__main__':
    import os, sys
    arguments = sys.argv[1:]
    # file_name
    run_file_name = arguments[0]

    extra_args = ''
    if len(arguments) > 1:
        extra_args = f"{' '.join(arguments[1:])} "

    os.environ['PARALLEL_TYPE'] = 'ds'
    command = f'deepspeed {extra_args}{run_file_name}'

    print(f'run command {command}')
    os.system(command)

project_llm_trainer-0.13.4.data/scripts/py_train
ADDED
@@ -0,0 +1,12 @@
#!python

if __name__ == '__main__':
    import os, sys
    arguments = sys.argv[1:]
    run_file_name = arguments[0]

    os.environ['PARALLEL_TYPE'] = 'none'
    command = f'python3 {run_file_name}'

    print(f'real command is {command}')
    os.system(command)