project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_trainer/trainer.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import List, Tuple
2
+
3
+ from torch.utils.data import Dataset
4
+
5
+ from .base_trainer import BaseTrainer
6
+ from .train_configs import TrainConfig
7
+ from .utils import pretrain_collate_fn
8
+ from .dataset import PretrainDataset
9
+
10
+
11
+ class Trainer(BaseTrainer):
12
+ def __init__(
13
+ self,
14
+ *,
15
+ train_config: TrainConfig,
16
+ eval_prompts: List[str],
17
+ ):
18
+ super().__init__(
19
+ train_config=train_config,
20
+ eval_prompts=eval_prompts,
21
+ kd_config=train_config.pretrain_config.kd_config,
22
+ gradient_accumulation_steps=train_config.pretrain_config.gradient_accumulation_steps
23
+ )
24
+
25
+ def _convert_train_args(self) -> Tuple[dict, dict, dict]:
26
+ parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
27
+ data_loader_kwargs.update({"collate_fn": pretrain_collate_fn})
28
+
29
+ return parallel_kwargs, data_loader_kwargs, sampler_kwargs
30
+
31
+ def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
32
+ file_path = self.train_config.file_dataset[file_idx]
33
+ max_seq_len = self.train_config.max_seq_len
34
+ return PretrainDataset(file_path, max_seq_len, max_seq_len), file_path
llm_trainer/utils.py ADDED
@@ -0,0 +1,547 @@
1
+ import random
2
+ from contextlib import nullcontext
3
+ import torch
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ import torch.nn.functional as F
6
+ from .tools import TrainerTools
7
+ import numpy as np
8
+ from typing import Union, List, Optional
9
+
10
+
11
+ def set_seed(seed=42):
12
+ random.seed(seed)
13
+ np.random.seed(seed)
14
+ torch.manual_seed(seed)
15
+ torch.cuda.manual_seed(seed)
16
+ torch.cuda.manual_seed_all(seed)
17
+
18
+
19
+ def autocast(device_type):
20
+ if TrainerTools().use_amp:
21
+ dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
22
+ return torch.autocast(
23
+ device_type=device_type,
24
+ dtype=dtype,
25
+ enabled=True,
26
+ cache_enabled=None
27
+ )
28
+ else:
29
+ return nullcontext()
30
+
31
+
32
+
33
+ def create_doc_boundary_mask(
34
+ input_ids: torch.Tensor,
35
+ dtype: torch.dtype
36
+ ) -> torch.Tensor:
37
+ """
38
+ 根据文档结束符 (eot) 的位置,创建一个 attention mask 来阻止跨文档的注意力。
39
+
40
+ 这个函数生成的 mask 会阻止一个 token 关注 (attend to) 属于前面文档的 tokens。
41
+ 例如,对于输入 `[[1, 2, eot, 3, 4, eot]]`,
42
+ tokens `3` 和 `4` 将无法关注 `1`, `2`, 和第一个 `eot`。
43
+
44
+ Args:
45
+ input_ids (torch.Tensor): 输入的 token ID 张量,形状为 (bsz, seq_len)。
46
+ dtype (torch.dtype): 数据类型。
47
+
48
+ Returns:
49
+ torch.Tensor: 符合 attention 机制要求的 mask 张量,
50
+ 形状为 (bsz, 1, seq_len, seq_len)。
51
+ 值为 -inf 的位置表示被屏蔽,值为 0 的位置表示允许注意力。
52
+ """
53
+ # 获取 batch size 和 sequence length
54
+ bsz, seq_len = input_ids.shape
55
+
56
+ # 1. 确定每个 eot_token 的位置
57
+ # is_eot 是一个布尔张量,形状为 (bsz, seq_len)
58
+ is_eot = (input_ids == TrainerTools().tokenizer.end)
59
+
60
+ # 2. 为每个 token 分配一个文档 ID
61
+ # 我们使用 cumsum (累加和) 来创建递增的文档 ID。一个 token 所属的文档 ID,
62
+ # 取决于它前面有多少个 eot。
63
+ # 示例:
64
+ # input_ids: [[1, 2, 3, eot, 4, 5, eot]]
65
+ # is_eot: [F, F, F, T, F, F, T] -> [0, 0, 0, 1, 0, 0, 1]
66
+ # doc_ids_ending: [0, 0, 0, 1, 1, 1, 2] (cumsum 的结果)
67
+ # doc_ids: [0, 0, 0, 0, 1, 1, 1] (向右移位后的结果)
68
+ # 这个结果正确地将文档 0 分配给了前四个 token,将文档 1 分配给了后三个 token。
69
+ doc_ids_ending = torch.cumsum(is_eot, dim=-1)
70
+ doc_ids = F.pad(doc_ids_ending[:, :-1], (1, 0), value=0)
71
+
72
+ # 3. 通过比较 query 和 key 的文档 ID 来创建 mask
73
+ # 我们的目标是:当 query token 所在的文档 ID 大于 key token 所在的文档 ID 时,进行屏蔽。
74
+ # query_doc_ids 形状: (bsz, seq_len, 1)
75
+ # key_doc_ids 形状: (bsz, 1, seq_len)
76
+ query_doc_ids = doc_ids.unsqueeze(2)
77
+ key_doc_ids = doc_ids.unsqueeze(1)
78
+
79
+ # 利用 PyTorch 的广播机制,`query_doc_ids > key_doc_ids` 会创建一个
80
+ # 形状为 (bsz, seq_len, seq_len) 的布尔张量。
81
+ # 当 query 的文档 ID 大于 key 的文档 ID 时,值为 True,这正是我们需要屏蔽的位置。
82
+ boundary_mask = query_doc_ids > key_doc_ids
83
+
84
+ # 4. 将布尔 mask 转换为 attention 机制所需的浮点数 mask (-inf 和 0)
85
+ final_mask = torch.zeros(
86
+ (bsz, seq_len, seq_len), device=input_ids.device, dtype=dtype
87
+ )
88
+ final_mask.masked_fill_(boundary_mask, torch.finfo(dtype).min)
89
+
90
+ # 5. 增加一个维度以匹配 attention head 的输入要求 (bsz, num_heads, seq_len, seq_len)
91
+ # 这里我们只生成一个 mask,它可以被广播到所有的 head。
92
+ return final_mask.unsqueeze(1)
93
+
94
+
95
+ def generate_position_ids(input_ids: torch.Tensor):
96
+ """
97
+ 为打包序列生成 position_ids 张量。
98
+
99
+ 参数:
100
+ input_ids (torch.Tensor): 输入的 token ID 张量 (batch_size, sequence_length)。
101
+ end_of_text_id (int): 代表文本结束的特殊 token ID。
102
+
103
+ 返回:
104
+ torch.Tensor: 生成的 position_ids 张量。
105
+ """
106
+ # 获取输入张量的形状
107
+ batch_size, seq_length = input_ids.shape
108
+
109
+ # 创建一个与输入形状相同,全为0的张量来存储position_ids
110
+ # 第一个token的位置永远是0,所以这个初始化是正确的
111
+ position_ids = torch.zeros_like(input_ids, dtype=torch.long)
112
+
113
+ # 从第二个时间步 (t=1) 开始遍历整个序列
114
+ for t in range(1, seq_length):
115
+ # 检查前一个时间步 (t-1) 的token是否为 EOT token
116
+ # 这会为批次中的每个序列生成一个布尔值
117
+ is_reset_token = (input_ids[:, t - 1] == TrainerTools().tokenizer.end)
118
+
119
+ # 获取前一个时间步的位置ID
120
+ prev_position_ids = position_ids[:, t - 1]
121
+
122
+ # 如果前一个token是EOT,当前位置重置为0;否则,在前一个位置上加1
123
+ # torch.where 会根据 is_reset_token 的布尔值进行选择
124
+ position_ids[:, t] = torch.where(is_reset_token, 0, prev_position_ids + 1)
125
+
126
+ return position_ids
127
+
128
+
129
+ def calc_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
130
+ """
131
+ 根据 attention_mask 计算 position_ids,主要用于 Left Padding 场景。
132
+ mask: [0, 0, 1, 1, 1] -> position_ids: [0, 0, 0, 1, 2]
133
+ """
134
+ position_ids = attention_mask.long().cumsum(-1) - 1
135
+ position_ids.masked_fill_(attention_mask == 0, 0)
136
+ return position_ids
137
+
138
+
139
+ def repeat_image_tok(
140
+ tokens: torch.Tensor,
141
+ tokens_per_image: int
142
+ ) -> torch.Tensor:
143
+ # tokens_per_image=3 -> <image>...xxxx -> <image><image><image>...xxx
144
+ image_tok = TrainerTools().tokenizer.image
145
+ mask = (tokens == image_tok)
146
+ if not mask.any():
147
+ return tokens
148
+
149
+ idxs = torch.nonzero(mask, as_tuple=False)
150
+ if idxs.numel() == 0:
151
+ return tokens
152
+
153
+ image_tok_idx = idxs[0, 0].item()
154
+ repeat_image_toks = torch.tensor([image_tok] * tokens_per_image, dtype=tokens.dtype, device=tokens.device)
155
+ new_tokens = torch.cat([tokens[:image_tok_idx], repeat_image_toks, tokens[image_tok_idx + 1:]], dim=-1)
156
+ return new_tokens
157
+
158
+
159
+ def batch_repeat_image_tok(
160
+ tokens: torch.Tensor,
161
+ tokens_per_image: int
162
+ ) -> torch.Tensor:
163
+ new_tokens = []
164
+
165
+ for token in tokens:
166
+ new_tokens.append(repeat_image_tok(token, tokens_per_image))
167
+
168
+ return torch.stack(new_tokens, dim=0)
169
+
170
+
171
+ def pretrain_collate_fn(batch_data):
172
+ # [[x,x,x], [y,y,y]]
173
+ inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
174
+ # crossEntropy默认的ignore_index是-100
175
+ labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
176
+
177
+ # inputs, labels
178
+ return {
179
+ 'inputs': inputs,
180
+ 'labels': labels
181
+ }
182
+
183
+
184
+ def get_sft_collate_fn(mask_prompt: bool):
185
+ def sft_collate_fn(batch_data):
186
+ """
187
+ 如果是sft,则不计算prompt部分的loss, 例如:
188
+ logits: [USER]你好[BOT]我好[SEP]
189
+ labels: [USER]你好[BOT]我好[SEP]
190
+
191
+ shift_logits: [USER]你好[BOT]我好
192
+ shift_labels: 你好[BOT]我好[SEP]
193
+
194
+ mask_labels: mask mask mask mask 我好[SEP]
195
+ * mask=-100和pad一样
196
+
197
+
198
+ 多轮对话场景
199
+ [USER]你好[BOT]我好[SEP][USER]很好[BOT]不好[SEP]
200
+ mask: mask mask mask mask 我好[SEP] mask mask mask mask 不好[SEP]
201
+ """
202
+ batch_train_data = []
203
+ image_tags = []
204
+ for item in batch_data:
205
+ batch_train_data.append(item['inputs'])
206
+ image_tags.append(item['image_tag'])
207
+
208
+ # [[x,x,x], [y,y,y]]
209
+ inputs = pad_sequence(batch_train_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
210
+ # crossEntropy默认的ignore_index是-100
211
+ labels = pad_sequence(batch_train_data, batch_first=True, padding_value=-100)
212
+
213
+ if mask_prompt:
214
+ labels = _mask_prompt(labels)
215
+
216
+ return {
217
+ 'inputs': inputs,
218
+ 'labels': labels,
219
+ 'image_tags': image_tags
220
+ }
221
+
222
+ return sft_collate_fn
223
+
224
+
225
+ def get_dpo_collate_fn(mask_prompt: bool):
226
+ def dpo_collate_fn(batch_data):
227
+ # batch_data: [{'chosen': chosen, 'rejected': rejected}, {'chosen': chosen, 'rejected': rejected}]
228
+ chosen_inputs = []
229
+ chosen_labels = []
230
+ rejected_inputs = []
231
+ rejected_labels = []
232
+
233
+ max_len = 0
234
+ for key in ['chosen', 'rejected']:
235
+ max_len = max(max(len(item[key]) for item in batch_data), max_len)
236
+
237
+ for item in batch_data:
238
+ chosen_sequence = item['chosen']
239
+ chosen_inputs.append(chosen_sequence + [TrainerTools().tokenizer.pad] * (max_len - len(chosen_sequence)))
240
+ chosen_labels.append(chosen_sequence + [-100] * (max_len - len(chosen_sequence)))
241
+
242
+ rejected_sequence = item['rejected']
243
+ rejected_inputs.append(rejected_sequence + [TrainerTools().tokenizer.pad] * (max_len - len(rejected_sequence)))
244
+ rejected_labels.append(rejected_sequence + [-100] * (max_len - len(rejected_sequence)))
245
+
246
+ chosen_inputs = torch.tensor(chosen_inputs).long()
247
+ chosen_labels = torch.tensor(chosen_labels).long()
248
+ if mask_prompt:
249
+ chosen_labels = _mask_prompt(chosen_labels)
250
+
251
+ rejected_inputs = torch.tensor(rejected_inputs).long()
252
+ rejected_labels = torch.tensor(rejected_labels).long()
253
+ if mask_prompt:
254
+ rejected_labels = _mask_prompt(rejected_labels)
255
+
256
+ return {
257
+ 'chosen_inputs': chosen_inputs,
258
+ 'chosen_labels': chosen_labels,
259
+ 'rejected_inputs': rejected_inputs,
260
+ 'rejected_labels': rejected_labels
261
+ }
262
+
263
+ return dpo_collate_fn
264
+
265
+
266
+ def split_batch(data_per_batch: dict) -> list[dict]:
267
+ """
268
+ from: data_per_batch("sequences": [group_size, max_generate_len] ...)
269
+ to: [dict("sequences": [max_generate_len] ...) ... group_size]
270
+ """
271
+
272
+ group_size = data_per_batch['sequence_ids'].size(0)
273
+ # [{"sequence_ids": xxx, "old_log_probs": xxx...}, ...]
274
+ group_data = [{} for _ in range(group_size)]
275
+
276
+ keys = (
277
+ 'sequence_ids',
278
+ 'old_log_probs',
279
+ 'ref_log_probs',
280
+ 'advantages',
281
+ 'attention_mask',
282
+ 'mask',
283
+ )
284
+
285
+ for key in keys:
286
+ value = data_per_batch[key]
287
+ if value is None:
288
+ vals = [None] * group_size
289
+ else:
290
+ vals = torch.unbind(value)
291
+
292
+ for i, v in enumerate(vals):
293
+ group_data[i][key] = v
294
+
295
+ return group_data
296
+
297
+
298
+ def join_batch(batch_data: list[dict]) -> dict:
299
+ """
300
+ from: [dict("sequences": [max_generate_len] ...), ...]
301
+ to: dict("sequences": max_generate_len, ...)
302
+ """
303
+
304
+ result = {}
305
+ keys = (
306
+ 'sequence_ids',
307
+ 'old_log_probs',
308
+ 'ref_log_probs',
309
+ 'advantages',
310
+ 'attention_mask',
311
+ 'mask',
312
+ )
313
+
314
+ for key in keys:
315
+ # [sequence_ids, sequence_ids ...]
316
+ # shape [batch_size, seq_len]
317
+ vals = [item[key] for item in batch_data]
318
+ if all(v is not None for v in vals):
319
+ data = _zero_pad_sequences(vals, "left")
320
+ else:
321
+ data = None
322
+ result[key] = data
323
+
324
+ return result
325
+
326
+
327
+ # 默认使用torch提供的pad_sequence
328
+ # 如果pad_sequence不支持padding_side参数,则将改参数置为False,使用反转的方式
329
+ _use_origin_pad_sequence = True
330
+ def left_pad_sequence(
331
+ sequences: Union[torch.Tensor, List[torch.Tensor]],
332
+ padding_value: float,
333
+ ) -> torch.Tensor:
334
+ global _use_origin_pad_sequence
335
+
336
+ if _use_origin_pad_sequence:
337
+ try:
338
+ return pad_sequence(sequences, batch_first=True, padding_value=padding_value, padding_side='left')
339
+ except TypeError:
340
+ _use_origin_pad_sequence = False
341
+ return left_pad_sequence(sequences, padding_value)
342
+ else:
343
+ # 反转每个序列的顺序(如 [1,2,3] → [3,2,1])
344
+ reversed_sequences = [seq.flip(dims=(0,)) for seq in sequences]
345
+ # 使用默认的右侧填充
346
+ padded_reversed = pad_sequence(reversed_sequences, batch_first=True, padding_value=padding_value)
347
+ # 再次反转序列顺序,恢复原始方向(填充在左侧)
348
+ return padded_reversed.flip(dims=(1,))
349
+
350
+
351
+ _use_memory_efficient_log_softmax = True
352
+ def log_softmax(logits, index) -> torch.Tensor:
353
+ if _use_memory_efficient_log_softmax:
354
+ return _selective_log_softmax(logits, index)
355
+
356
+ # Convert raw logits into log probabilities along the vocabulary axis.
357
+ # [batch_size, seq_len, vocab_size]
358
+ log_probs = F.log_softmax(logits, dim=-1)
359
+
360
+ # Reshape input_ids from (batch_size, seq_len) to (batch_size, seq_len, 1) for gathering.
361
+ # Then, gather the log probability for each token in input_ids.
362
+ selected_log_probs = log_probs.gather(dim=-1, index=index.unsqueeze(-1))
363
+
364
+ # Remove the extra last dimension to get back to shape (batch_size, seq_len).
365
+ return selected_log_probs.squeeze(-1)
366
+
367
+
368
+ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
369
+ """Whiten values with masked values."""
370
+ mean, var = _masked_mean(values, mask), _masked_var(values, mask)
371
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
372
+ if not shift_mean:
373
+ whitened += mean
374
+ return whitened
375
+
376
+
377
+ def truncate_sequences_at_eos(
378
+ sequences: torch.Tensor,
379
+ eos_token_id: int,
380
+ pad_token_id: int
381
+ ) -> torch.Tensor:
382
+ """
383
+ 高效地将批处理中的序列在第一个EOS标记处截断。
384
+ 第一个EOS标记之后的所有内容(不包括EOS自身)将被替换为pad_token_id。
385
+
386
+ 这是一个向量化的实现,以确保在GPU上的性能。
387
+ 它使用 torch.where,因此不依赖于 pad_token_id 必须为0。
388
+
389
+ Args:
390
+ sequences (torch.Tensor): 批处理序列, 形状为 (batch_size, seq_len)。
391
+ eos_token_id (int): 句子结束标记的ID。
392
+ pad_token_id (int): 填充标记的ID。
393
+
394
+ Returns:
395
+ torch.Tensor: 截断后的序列,形状与输入相同。
396
+ """
397
+ # 创建一个布尔掩码,标记所有EOS token的位置
398
+ eos_mask = (sequences == eos_token_id)
399
+
400
+ # 找到每行中第一个True(即第一个EOS token)的索引
401
+ # .int() 是为了兼容旧版torch,argmax需要非布尔类型
402
+ first_eos_indices = torch.argmax(eos_mask.int(), dim=1)
403
+
404
+ # 检查哪些序列确实包含了EOS token。
405
+ # 如果某一行完全没有EOS, argmax会返回0, 这会产生歧义。
406
+ has_eos = eos_mask.any(dim=1)
407
+
408
+ # 对于没有EOS token的序列,将截断索引设置为序列最大长度,以防错误截断
409
+ first_eos_indices[~has_eos] = sequences.shape[1]
410
+
411
+ # 创建一个 [0, 1, 2, ..., seq_len-1] 的索引张量
412
+ indices_mask = torch.arange(sequences.shape[1], device=sequences.device)
413
+
414
+ # 利用广播机制创建一个掩码,标记所有应保留的token
415
+ # 对于每个序列,当 token_index < first_eos_index 时为True
416
+ keep_mask = indices_mask < first_eos_indices.unsqueeze(1)
417
+
418
+ # 使用 torch.where 进行安全替换
419
+ # 如果 keep_mask 为 True,则保留原始序列的token,否则替换为 pad_token_id
420
+ truncated_sequences = torch.where(
421
+ keep_mask,
422
+ sequences,
423
+ pad_token_id
424
+ )
425
+
426
+ return truncated_sequences
427
+
428
+
429
+ def disable_dropout_in_model(model: torch.nn.Module) -> None:
430
+ for module in model.modules():
431
+ if isinstance(module, torch.nn.Dropout):
432
+ module.p = 0
433
+
434
+
435
+ def _masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
436
+ """Compute mean of tensor with a masked values."""
437
+ if axis is not None:
438
+ return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
439
+ else:
440
+ return (values * mask).sum() / mask.sum()
441
+
442
+
443
+ def _masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
444
+ """Compute variance of tensor with masked values."""
445
+ mean = _masked_mean(values, mask)
446
+ centered_values = values - mean
447
+ variance = _masked_mean(centered_values**2, mask)
448
+ if unbiased:
449
+ mask_sum = mask.sum()
450
+ if mask_sum == 0:
451
+ return torch.tensor(0.0, device=values.device, dtype=values.dtype)
452
+
453
+ # note that if mask_sum == 1, then there is a division by zero issue
454
+ # to avoid it you just need to use a larger minibatch_size
455
+ bessel_correction = mask_sum / (mask_sum - 1)
456
+ variance = variance * bessel_correction
457
+ return variance
458
+
459
+
460
+ def _selective_log_softmax(logits, index) -> torch.Tensor:
461
+ if logits.dtype in [torch.float32, torch.float64]:
462
+ selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
463
+ logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
464
+ per_token_logps = selected_logits - logsumexp_values
465
+ else:
466
+ per_token_logps = []
467
+ for row_logits, row_labels in zip(logits, index):
468
+ row_logps = F.log_softmax(row_logits, dim=-1)
469
+ row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
470
+ per_token_logps.append(row_per_token_logps)
471
+ per_token_logps = torch.stack(per_token_logps)
472
+ return per_token_logps
473
+
474
+
475
+ def _mask_prompt(labels):
476
+ """
477
+ Mask 掉 Prompt 部分以及固定的模版标签,只保留模型需要生成的真正内容。
478
+ 策略:
479
+ 1. <system>/<user> 到 </s> 之间:全部 Mask。
480
+ 2. </s> 后的 <assistant>:Mask
481
+ 3. <assistant> 后的 <answer>:Mask
482
+ 4. <assistant> 后的 <think>:保留
483
+ 例如:
484
+ 1. 原始: <system>system</s><user>user</s><assistant>content</s>
485
+ mask: mask mask mask mask mask mask mask content</s>
486
+ 2. 原始: <system>system</s><user>user</s><assistant><answer>content</answer></s>
487
+ mask: mask mask mask mask mask mask mask mask content</answer></s>
488
+ 3. 原始:<system>system</s><user>user</s><assistant><think>think</think><answer>content</answer></s>
489
+ mask: mask mask mask mask mask mask mask <think>think</think><answer>content</answer></s>
490
+ """
491
+ system_token_id = TrainerTools().tokenizer.system
492
+ user_token_id = TrainerTools().tokenizer.user
493
+ end_token_id = TrainerTools().tokenizer.end
494
+ assistant_token_id = TrainerTools().tokenizer.assistant
495
+ answer_start_token_id = TrainerTools().tokenizer.answer_start
496
+ think_start_token_id = TrainerTools().tokenizer.think_start
497
+ ignore_index = -100
498
+
499
+ for batch, label in enumerate(labels):
500
+ start_index = -1
501
+ seq_len = len(label)
502
+
503
+ for index, token in enumerate(label):
504
+ if token == system_token_id or token == user_token_id:
505
+ if start_index != -1:
506
+ labels[batch, start_index: index] = ignore_index
507
+
508
+ start_index = index
509
+
510
+ elif token == end_token_id and start_index != -1:
511
+ end_mask_index = index
512
+ next_idx = index + 1
513
+
514
+ if next_idx < seq_len and label[next_idx] == assistant_token_id:
515
+ end_mask_index = next_idx
516
+
517
+ after_assistant_idx = next_idx + 1
518
+ if after_assistant_idx < seq_len:
519
+ token_after = label[after_assistant_idx]
520
+
521
+ if token_after == answer_start_token_id:
522
+ end_mask_index = after_assistant_idx
523
+ elif token_after == think_start_token_id:
524
+ pass
525
+
526
+ labels[batch, start_index: end_mask_index + 1] = ignore_index
527
+ start_index = -1
528
+
529
+ # 循环结束后,如果 start_index 仍然不是 -1,说明序列被截断了(Truncation)。
530
+ # 此时整个尾部都属于 Prompt(例如 User 说到一半被截断),必须全部 Mask。
531
+ if start_index != -1:
532
+ labels[batch, start_index:] = ignore_index
533
+
534
+ return labels
535
+
536
+
537
+ def _zero_pad_sequences(
538
+ sequences: list[torch.Tensor], side: str = "left"
539
+ ) -> torch.Tensor:
540
+ assert side in ("left", "right")
541
+ max_len = max(seq.size(0) for seq in sequences)
542
+ padded_sequences = []
543
+ for seq in sequences:
544
+ pad_len = max_len - seq.size(0)
545
+ padding = (pad_len, 0) if side == "left" else (0, pad_len)
546
+ padded_sequences.append(F.pad(seq, padding))
547
+ return torch.stack(padded_sequences, dim=0)
@@ -0,0 +1,15 @@
1
+ #!python
2
+
3
+ if __name__ == '__main__':
4
+ import sys
5
+ arguments = sys.argv[1:]
6
+ hidden_size = int(arguments[0])
7
+ if len(arguments) > 1:
8
+ multiple_of = int(arguments[1])
9
+ else:
10
+ multiple_of = 64
11
+
12
+ intermediate_size = 4 * hidden_size
13
+ intermediate_size = int(2 * intermediate_size / 3)
14
+ intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
15
+ print(f'intermediate_size={intermediate_size}')
@@ -0,0 +1,21 @@
1
+ #!python
2
+
3
+ if __name__ == '__main__':
4
+ import os, sys
5
+ arguments = sys.argv[1:]
6
+ # file_name
7
+ run_file_name = arguments[0]
8
+
9
+ extra_args = ''
10
+ if len(arguments) > 1:
11
+ extra_args = f"{' '.join(arguments[1:])} "
12
+
13
+ os.environ['PARALLEL_TYPE'] = 'ddp'
14
+
15
+ if len(extra_args) == 0:
16
+ extra_args = '--standalone --nproc_per_node=gpu '
17
+
18
+ command = f'torchrun {extra_args}{run_file_name}'
19
+
20
+ print(f'run command {command}')
21
+ os.system(command)
@@ -0,0 +1,17 @@
1
+ #!python
2
+
3
+ if __name__ == '__main__':
4
+ import os, sys
5
+ arguments = sys.argv[1:]
6
+ # file_name
7
+ run_file_name = arguments[0]
8
+
9
+ extra_args = ''
10
+ if len(arguments) > 1:
11
+ extra_args = f"{' '.join(arguments[1:])} "
12
+
13
+ os.environ['PARALLEL_TYPE'] = 'ds'
14
+ command = f'deepspeed {extra_args}{run_file_name}'
15
+
16
+ print(f'run command {command}')
17
+ os.system(command)