project-llm-trainer 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/dpo_trainer.py +6 -4
- llm_trainer/generate_utils.py +6 -15
- llm_trainer/grpo_trainer.py +3 -2
- llm_trainer/log.py +1 -0
- llm_trainer/parallel.py +3 -0
- llm_trainer/sft_trainer.py +1 -0
- llm_trainer/tools.py +1 -9
- llm_trainer/trainer.py +34 -11
- llm_trainer/utils.py +148 -38
- {project_llm_trainer-0.6.0.dist-info → project_llm_trainer-0.7.0.dist-info}/METADATA +1 -1
- project_llm_trainer-0.7.0.dist-info/RECORD +33 -0
- project_llm_trainer-0.6.0.dist-info/RECORD +0 -33
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.6.0.dist-info → project_llm_trainer-0.7.0.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.6.0.dist-info → project_llm_trainer-0.7.0.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import time
|
|
2
1
|
from typing import Tuple, List, Optional
|
|
3
2
|
import torch
|
|
4
3
|
from torch.utils.data import Dataset
|
|
@@ -11,7 +10,10 @@ from .train_configs import TrainConfig
|
|
|
11
10
|
from .dataset import DPODataset
|
|
12
11
|
from .loss import DPOLoss
|
|
13
12
|
from .tools import TrainerTools
|
|
14
|
-
from .utils import
|
|
13
|
+
from .utils import (
|
|
14
|
+
autocastcontext,
|
|
15
|
+
get_dpo_collate_fn
|
|
16
|
+
)
|
|
15
17
|
from .partition_utils import sync_model_params
|
|
16
18
|
|
|
17
19
|
from .checkpoint import (
|
|
@@ -34,7 +36,7 @@ class DPOTrainer(Trainer):
|
|
|
34
36
|
eval_prompts=eval_prompts,
|
|
35
37
|
eval_image_tags=eval_image_tags
|
|
36
38
|
)
|
|
37
|
-
|
|
39
|
+
self.packed_sequences = False
|
|
38
40
|
self.ref_model = self._init_ref_model()
|
|
39
41
|
|
|
40
42
|
def _init_ref_model(self):
|
|
@@ -201,7 +203,7 @@ class DPOTrainer(Trainer):
|
|
|
201
203
|
if TrainerTools().parallel.parallel_train:
|
|
202
204
|
self.train_model.require_backward_grad_sync = need_update_grad
|
|
203
205
|
|
|
204
|
-
with
|
|
206
|
+
with autocastcontext(TrainerTools().parallel.device_type):
|
|
205
207
|
policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
|
|
206
208
|
policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
|
|
207
209
|
aux_loss = policy_outputs.get('aux_loss')
|
llm_trainer/generate_utils.py
CHANGED
|
@@ -3,7 +3,10 @@ from contextlib import nullcontext
|
|
|
3
3
|
import torch
|
|
4
4
|
from llm_model import VlmModel, KVCache
|
|
5
5
|
from .tools import TrainerTools
|
|
6
|
-
from .utils import
|
|
6
|
+
from .utils import (
|
|
7
|
+
autocastcontext,
|
|
8
|
+
batch_repeat_image_tok
|
|
9
|
+
)
|
|
7
10
|
|
|
8
11
|
|
|
9
12
|
def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
|
|
@@ -124,13 +127,7 @@ def _generate(
|
|
|
124
127
|
如果temperature很大但内容单一,需要增大k、p
|
|
125
128
|
"""
|
|
126
129
|
use_kv_cache = True
|
|
127
|
-
|
|
128
|
-
ctx = torch.autocast(
|
|
129
|
-
device_type=device,
|
|
130
|
-
dtype=TrainerTools().dtype,
|
|
131
|
-
enabled=True,
|
|
132
|
-
cache_enabled=None
|
|
133
|
-
) if TrainerTools().use_amp else nullcontext()
|
|
130
|
+
ctx = autocastcontext(device)
|
|
134
131
|
|
|
135
132
|
if isinstance(model, VlmModel):
|
|
136
133
|
tokens = batch_repeat_image_tok(tokens, tokens_per_image)
|
|
@@ -330,13 +327,7 @@ def batch_generate(
|
|
|
330
327
|
device: Union[str, torch.device, int]
|
|
331
328
|
):
|
|
332
329
|
use_kv_cache = True
|
|
333
|
-
|
|
334
|
-
ctx = torch.autocast(
|
|
335
|
-
device_type=device,
|
|
336
|
-
dtype=TrainerTools().dtype,
|
|
337
|
-
enabled=True,
|
|
338
|
-
cache_enabled=None
|
|
339
|
-
) if TrainerTools().use_amp else nullcontext()
|
|
330
|
+
ctx = autocastcontext(device)
|
|
340
331
|
|
|
341
332
|
if isinstance(model, VlmModel):
|
|
342
333
|
tokens = batch_repeat_image_tok(tokens, tokens_per_image)
|
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import time
|
|
2
1
|
from typing import Tuple, List, Union, Callable, Optional
|
|
3
2
|
import torch
|
|
4
3
|
from torch.utils.data import Dataset
|
|
@@ -14,6 +13,7 @@ from .loss import GRPOLoss
|
|
|
14
13
|
from .tools import TrainerTools
|
|
15
14
|
from .generate_utils import batch_generate
|
|
16
15
|
from .log import log
|
|
16
|
+
from .utils import autocastcontext
|
|
17
17
|
|
|
18
18
|
from .partition_utils import (
|
|
19
19
|
sync_model_params,
|
|
@@ -41,6 +41,7 @@ class GRPOTrainer(Trainer):
|
|
|
41
41
|
eval_image_tags=eval_image_tags
|
|
42
42
|
)
|
|
43
43
|
|
|
44
|
+
self.packed_sequences = False
|
|
44
45
|
self.reward_func = reward_func
|
|
45
46
|
self.ref_model = self._init_ref_model()
|
|
46
47
|
|
|
@@ -341,7 +342,7 @@ class GRPOTrainer(Trainer):
|
|
|
341
342
|
log(f'start train for batch {batch}/{batch_count_per_file}')
|
|
342
343
|
|
|
343
344
|
for grpo_step in range(self.train_config.grpo_config.grpo_steps):
|
|
344
|
-
with
|
|
345
|
+
with autocastcontext(TrainerTools().parallel.device_type):
|
|
345
346
|
loss, aux_loss = self._maximize_grpo_objective(rollout_data)
|
|
346
347
|
if aux_loss_coef and aux_loss:
|
|
347
348
|
loss += aux_loss_coef * aux_loss
|
llm_trainer/log.py
CHANGED
llm_trainer/parallel.py
CHANGED
llm_trainer/sft_trainer.py
CHANGED
|
@@ -21,6 +21,7 @@ class SFTTrainer(Trainer):
|
|
|
21
21
|
eval_prompts=eval_prompts,
|
|
22
22
|
eval_image_tags=eval_image_tags
|
|
23
23
|
)
|
|
24
|
+
self.packed_sequences = False
|
|
24
25
|
|
|
25
26
|
def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
|
|
26
27
|
sft_collate_fn = get_sft_collate_fn(self.train_config.mask_prompt)
|
llm_trainer/tools.py
CHANGED
|
@@ -31,15 +31,7 @@ class TrainerTools:
|
|
|
31
31
|
self.tokenizer = Tokenizer(os.environ.get('TOKENIZERS_TYPE', 'zh_llama'))
|
|
32
32
|
self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
|
|
33
33
|
|
|
34
|
-
|
|
35
|
-
self.dtype = dtypes[dtype] if dtype in dtypes else None
|
|
36
|
-
|
|
37
|
-
if not self.dtype:
|
|
38
|
-
self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
|
39
|
-
|
|
40
|
-
log(f'word_size={self.parallel.world_size},'
|
|
41
|
-
f' use_amp={self.use_amp},'
|
|
42
|
-
f' dtype={self.dtype}')
|
|
34
|
+
log(f'word_size={self.parallel.world_size}, use_amp={self.use_amp}')
|
|
43
35
|
|
|
44
36
|
def _new_parallel(self):
|
|
45
37
|
parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
|
llm_trainer/trainer.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from contextlib import nullcontext
|
|
2
1
|
from typing import Optional, Tuple, List, Dict, Any
|
|
3
2
|
import copy
|
|
4
3
|
|
|
@@ -37,6 +36,9 @@ from .checkpoint import (
|
|
|
37
36
|
|
|
38
37
|
from .utils import (
|
|
39
38
|
set_seed,
|
|
39
|
+
autocastcontext,
|
|
40
|
+
create_doc_boundary_mask,
|
|
41
|
+
generate_position_ids,
|
|
40
42
|
pretrain_collate_fn,
|
|
41
43
|
)
|
|
42
44
|
|
|
@@ -55,6 +57,17 @@ class Trainer:
|
|
|
55
57
|
):
|
|
56
58
|
set_seed()
|
|
57
59
|
|
|
60
|
+
# 是否打包序列,仅pretrain阶段需要打包序列,
|
|
61
|
+
# [[1, 1, eos, 2, 2, eos]]
|
|
62
|
+
# doc_boundary_mask=[[[[0., 0., 0., 0., 0., 0.],
|
|
63
|
+
# [0., 0., 0., 0., 0., 0.],
|
|
64
|
+
# [0., 0., 0., 0., 0., 0.],
|
|
65
|
+
# [-inf, -inf, -inf, 0., 0., 0.],
|
|
66
|
+
# [-inf, -inf, -inf, 0., 0., 0.],
|
|
67
|
+
# [-inf, -inf, -inf, 0., 0., 0.]]]]
|
|
68
|
+
# position_ids=[[0, 1, 2, 0, 1, 2]]
|
|
69
|
+
self.packed_sequences = True
|
|
70
|
+
|
|
58
71
|
self.train_config: TrainConfig = train_config
|
|
59
72
|
self.eval_prompts = eval_prompts
|
|
60
73
|
self.eval_image_tags = eval_image_tags
|
|
@@ -81,13 +94,6 @@ class Trainer:
|
|
|
81
94
|
|
|
82
95
|
self.criterion, self.kd_loss = self._init_loss()
|
|
83
96
|
|
|
84
|
-
self.ctx = torch.autocast(
|
|
85
|
-
device_type=TrainerTools().parallel.device_type,
|
|
86
|
-
dtype=TrainerTools().dtype,
|
|
87
|
-
enabled=True,
|
|
88
|
-
cache_enabled=None
|
|
89
|
-
) if TrainerTools().use_amp else nullcontext()
|
|
90
|
-
|
|
91
97
|
load_checkpoint(
|
|
92
98
|
self.train_model,
|
|
93
99
|
optimizer=self.optimizer,
|
|
@@ -433,6 +439,14 @@ class Trainer:
|
|
|
433
439
|
|
|
434
440
|
raise e
|
|
435
441
|
|
|
442
|
+
def _get_model_dtype(self):
    """Return the floating dtype the forward pass computes in.

    Used to build the additive document-boundary attention mask for packed
    sequences, so the mask dtype should match the dtype of the attention scores.

    - DeepSpeed: ask the engine (index 0 of ``get_data_types()`` is taken as
      the model/compute dtype).
    - AMP enabled (non-DeepSpeed CUDA): the autocast dtype, chosen exactly as
      ``autocastcontext`` in utils chooses it.
    - AMP disabled (e.g. CPU): the forward runs outside autocast, so use the
      dtype of the model's parameters instead of a half-precision guess.
    """
    if isinstance(TrainerTools().parallel, DsParallel):
        import deepspeed
        assert isinstance(self.train_model, deepspeed.DeepSpeedEngine)
        return self.train_model.get_data_types()[0]

    if TrainerTools().use_amp:
        # Must agree with the dtype selected by utils.autocastcontext.
        return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

    # No autocast: compute happens in the parameters' own dtype.
    # NOTE(review): assumes all parameters share one floating dtype — confirm for mixed-dtype models.
    first_param = next(self.train_model.parameters(), None)
    if first_param is None:
        return torch.get_default_dtype()
    return first_param.dtype
|
|
449
|
+
|
|
436
450
|
def _eval(self, tag: str):
|
|
437
451
|
with unwrap_model_for_generation(self.train_model) as generate_model:
|
|
438
452
|
if TrainerTools().parallel.is_main_process:
|
|
@@ -526,8 +540,12 @@ class Trainer:
|
|
|
526
540
|
inputs, labels = inputs.to(TrainerTools().parallel.device), labels.to(TrainerTools().parallel.device)
|
|
527
541
|
attention_mask = inputs != TrainerTools().tokenizer.pad
|
|
528
542
|
|
|
529
|
-
if
|
|
530
|
-
|
|
543
|
+
if self.packed_sequences:
|
|
544
|
+
doc_boundary_mask = create_doc_boundary_mask(inputs, self._get_model_dtype())
|
|
545
|
+
position_ids = generate_position_ids(inputs)
|
|
546
|
+
else:
|
|
547
|
+
doc_boundary_mask = None
|
|
548
|
+
position_ids = None
|
|
531
549
|
|
|
532
550
|
if self.pixel_values_provider and 'image_tags' in batch_data:
|
|
533
551
|
image_tags = batch_data['image_tags']
|
|
@@ -535,10 +553,15 @@ class Trainer:
|
|
|
535
553
|
else:
|
|
536
554
|
pixel_values = None
|
|
537
555
|
|
|
538
|
-
|
|
556
|
+
if TrainerTools().parallel.parallel_train:
|
|
557
|
+
self.train_model.require_backward_grad_sync = need_update_grad
|
|
558
|
+
|
|
559
|
+
with autocastcontext(TrainerTools().parallel.device_type):
|
|
539
560
|
result = self.train_model(
|
|
540
561
|
inputs,
|
|
541
562
|
attention_mask=attention_mask,
|
|
563
|
+
doc_boundary_mask=doc_boundary_mask,
|
|
564
|
+
position_ids=position_ids,
|
|
542
565
|
pixel_values=pixel_values
|
|
543
566
|
)
|
|
544
567
|
|
llm_trainer/utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import random
|
|
2
|
+
from contextlib import nullcontext
|
|
2
3
|
import torch
|
|
3
4
|
from torch.nn.utils.rnn import pad_sequence
|
|
4
5
|
import torch.nn.functional as F
|
|
@@ -14,6 +15,115 @@ def set_seed(seed=42):
|
|
|
14
15
|
torch.cuda.manual_seed_all(seed)
|
|
15
16
|
|
|
16
17
|
|
|
18
|
+
def autocastcontext(device_type):
|
|
19
|
+
if TrainerTools().use_amp:
|
|
20
|
+
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
|
|
21
|
+
return torch.autocast(
|
|
22
|
+
device_type=device_type,
|
|
23
|
+
dtype=dtype,
|
|
24
|
+
enabled=True,
|
|
25
|
+
cache_enabled=None
|
|
26
|
+
)
|
|
27
|
+
else:
|
|
28
|
+
return nullcontext()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def create_doc_boundary_mask(
    input_ids: torch.Tensor,
    dtype: torch.dtype,
    end_of_text_id=None,
) -> torch.Tensor:
    """Build an additive attention mask that blocks attention across packed documents.

    A token may not attend to tokens of any *earlier* document in the same row.
    Documents are delimited by the end-of-text token, which belongs to the
    document it closes. Example with eot as the delimiter: for
    ``[[1, 2, eot, 3, 4, eot]]``, tokens ``3``, ``4`` and the second ``eot``
    cannot attend to ``1``, ``2`` or the first ``eot``.

    This mask only blocks *earlier* documents. Keys in later documents are
    left at 0 here; causal masking is presumably applied separately by the model
    (TODO confirm).

    Args:
        input_ids: token ids of shape (bsz, seq_len).
        dtype: floating dtype of the returned mask.
        end_of_text_id: id that terminates a document. Defaults to
            ``TrainerTools().tokenizer.end``.

    Returns:
        Tensor of shape (bsz, 1, seq_len, seq_len), broadcastable over heads.
        Allowed positions are 0. Blocked positions hold ``torch.finfo(dtype).min``,
        a finite large negative value used instead of -inf.
    """
    if end_of_text_id is None:
        end_of_text_id = TrainerTools().tokenizer.end

    bsz, seq_len = input_ids.shape
    if seq_len == 0:
        # The shift-by-one padding below would otherwise produce length-1 doc ids.
        return torch.zeros((bsz, 1, 0, 0), device=input_ids.device, dtype=dtype)

    # 1. Locate document terminators, shape (bsz, seq_len).
    is_eot = input_ids == end_of_text_id

    # 2. Give every token a document id = number of eot tokens strictly before it.
    #    cumsum counts eots up to and including the token. Shifting right by one
    #    keeps each eot in the document it ends. Example:
    #      input_ids:      [1, 2, 3, eot, 4, 5, eot]
    #      cumsum:         [0, 0, 0, 1,   1, 1, 2]
    #      doc_ids:        [0, 0, 0, 0,   1, 1, 1]
    doc_ids_ending = torch.cumsum(is_eot, dim=-1)
    doc_ids = F.pad(doc_ids_ending[:, :-1], (1, 0), value=0)

    # 3. Block (query, key) pairs where the query's document comes after the key's.
    #    (bsz, seq_len, 1) > (bsz, 1, seq_len) broadcasts to (bsz, seq_len, seq_len).
    boundary_mask = doc_ids.unsqueeze(2) > doc_ids.unsqueeze(1)

    # 4. Convert the boolean mask to an additive float mask.
    final_mask = torch.zeros(
        (bsz, seq_len, seq_len), device=input_ids.device, dtype=dtype
    )
    final_mask.masked_fill_(boundary_mask, torch.finfo(dtype).min)

    # 5. Add a singleton head dimension so the mask broadcasts across attention heads.
    return final_mask.unsqueeze(1)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def generate_position_ids(input_ids: torch.Tensor, end_of_text_id=None):
    """Generate position ids for packed sequences, restarting at 0 after each end-of-text token.

    Example with eot as the delimiter: ``[[1, 1, eot, 2, 2, eot]]`` ->
    ``[[0, 1, 2, 0, 1, 2]]``. The eot token itself continues the numbering of
    the document it closes; the token after it starts again at 0.

    Args:
        input_ids: token ids of shape (batch_size, sequence_length).
        end_of_text_id: id that ends a document. Defaults to
            ``TrainerTools().tokenizer.end``.

    Returns:
        torch.long tensor with the same shape and device as ``input_ids``.
    """
    if end_of_text_id is None:
        end_of_text_id = TrainerTools().tokenizer.end

    batch_size, seq_length = input_ids.shape
    if seq_length == 0:
        return torch.zeros_like(input_ids, dtype=torch.long)

    # Position t starts a new document when the token at t-1 is EOT (position 0 always starts one).
    is_doc_start = torch.zeros_like(input_ids, dtype=torch.bool)
    is_doc_start[:, 1:] = input_ids[:, :-1] == end_of_text_id

    steps = torch.arange(seq_length, device=input_ids.device).expand(batch_size, seq_length)
    # Running max of start indices gives, for every position, the index where its document began.
    # This replaces a Python loop that issued one tensor op per timestep.
    start_indices = torch.where(is_doc_start, steps, torch.zeros_like(steps))
    doc_start = torch.cummax(start_indices, dim=-1).values
    return steps - doc_start
|
|
125
|
+
|
|
126
|
+
|
|
17
127
|
def repeat_image_tok(
|
|
18
128
|
tokens: torch.Tensor,
|
|
19
129
|
tokens_per_image: int
|
|
@@ -43,43 +153,6 @@ def batch_repeat_image_tok(
|
|
|
43
153
|
return torch.stack(new_tokens, dim=0)
|
|
44
154
|
|
|
45
155
|
|
|
46
|
-
def _pad_sequence(batch_data):
|
|
47
|
-
# [[x,x,x], [y,y,y]]
|
|
48
|
-
inputs = pad_sequence(batch_data, batch_first=True, padding_value=TrainerTools().tokenizer.pad)
|
|
49
|
-
# crossEntropy默认的ignore_index是-100
|
|
50
|
-
labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
|
|
51
|
-
|
|
52
|
-
return inputs, labels
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _mask_prompt(labels):
|
|
56
|
-
tokenizer = TrainerTools().tokenizer
|
|
57
|
-
# 支持多轮会话的mask
|
|
58
|
-
for batch, label in enumerate(labels):
|
|
59
|
-
start_index = -1
|
|
60
|
-
for index, token in enumerate(label):
|
|
61
|
-
if token == tokenizer.system or token == tokenizer.user:
|
|
62
|
-
start_index = index
|
|
63
|
-
elif token == tokenizer.end and start_index != -1:
|
|
64
|
-
labels[batch, start_index:index + 1] = -100
|
|
65
|
-
start_index = -1
|
|
66
|
-
|
|
67
|
-
return labels
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def _zero_pad_sequences(
|
|
71
|
-
sequences: list[torch.Tensor], side: str = "left"
|
|
72
|
-
) -> torch.Tensor:
|
|
73
|
-
assert side in ("left", "right")
|
|
74
|
-
max_len = max(seq.size(0) for seq in sequences)
|
|
75
|
-
padded_sequences = []
|
|
76
|
-
for seq in sequences:
|
|
77
|
-
pad_len = max_len - seq.size(0)
|
|
78
|
-
padding = (pad_len, 0) if side == "left" else (0, pad_len)
|
|
79
|
-
padded_sequences.append(F.pad(seq, padding))
|
|
80
|
-
return torch.stack(padded_sequences, dim=0)
|
|
81
|
-
|
|
82
|
-
|
|
83
156
|
def pretrain_collate_fn(batch_data):
|
|
84
157
|
inputs, labels = _pad_sequence(batch_data)
|
|
85
158
|
|
|
@@ -219,4 +292,41 @@ def join_batch(batch_data: list[dict]) -> dict:
|
|
|
219
292
|
data = None
|
|
220
293
|
result[key] = data
|
|
221
294
|
|
|
222
|
-
return result
|
|
295
|
+
return result
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _pad_sequence(batch_data):
    """Right-pad a batch of token tensors into model inputs and LM labels.

    Both outputs come from the same ``batch_data`` and are batch-first. ``inputs``
    is padded with the tokenizer's pad id. ``labels`` is padded with -100, the
    default ``ignore_index`` of cross-entropy, so padding does not contribute to
    the loss.

    Returns:
        (inputs, labels)
    """
    pad_token_id = TrainerTools().tokenizer.pad
    padded_inputs = pad_sequence(batch_data, batch_first=True, padding_value=pad_token_id)
    padded_labels = pad_sequence(batch_data, batch_first=True, padding_value=-100)
    return padded_inputs, padded_labels
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def _mask_prompt(labels):
    """Mask system/user turns in ``labels`` with -100, in place, and return ``labels``.

    Supports multi-turn conversations: in every row, each span that starts at a
    system or user token and runs up to and including the next end token is set
    to -100. Spans not closed by an end token are left untouched.
    """
    tok = TrainerTools().tokenizer
    for row_idx, row in enumerate(labels):
        span_start = -1
        # Snapshot the row as Python values; writes below only touch already-visited positions.
        for pos, token_id in enumerate(row.tolist()):
            if token_id == tok.system or token_id == tok.user:
                span_start = pos
                continue
            if span_start >= 0 and token_id == tok.end:
                labels[row_idx, span_start:pos + 1] = -100
                span_start = -1

    return labels
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def _zero_pad_sequences(
|
|
323
|
+
sequences: list[torch.Tensor], side: str = "left"
|
|
324
|
+
) -> torch.Tensor:
|
|
325
|
+
assert side in ("left", "right")
|
|
326
|
+
max_len = max(seq.size(0) for seq in sequences)
|
|
327
|
+
padded_sequences = []
|
|
328
|
+
for seq in sequences:
|
|
329
|
+
pad_len = max_len - seq.size(0)
|
|
330
|
+
padding = (pad_len, 0) if side == "left" else (0, pad_len)
|
|
331
|
+
padded_sequences.append(F.pad(seq, padding))
|
|
332
|
+
return torch.stack(padded_sequences, dim=0)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
|
+
llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
|
|
3
|
+
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
|
+
llm_trainer/dpo_trainer.py,sha256=_8ZwOKQH69c6Fa5Cey5hNep7XUoI4jPIXQaQcV3soGw,12367
|
|
5
|
+
llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
|
|
6
|
+
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
|
+
llm_trainer/generate_utils.py,sha256=zX5218RX4ltahCQCZVVCWQghCWhKslPk2NUnl_CakIE,15050
|
|
8
|
+
llm_trainer/grpo_trainer.py,sha256=0iWvpuMI5CDNIjH08Dd1ihZFqDYenVnHACiMY2GLJtg,16449
|
|
9
|
+
llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
|
|
10
|
+
llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
|
|
11
|
+
llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
|
|
12
|
+
llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
|
|
13
|
+
llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
|
|
14
|
+
llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
|
|
15
|
+
llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
|
|
16
|
+
llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
|
|
17
|
+
llm_trainer/sft_trainer.py,sha256=LudTRIaqLQYy6ym6jjMX7v9xtFBJelrR3nnPCwb48nM,1821
|
|
18
|
+
llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
|
|
19
|
+
llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
|
|
20
|
+
llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
|
|
21
|
+
llm_trainer/trainer.py,sha256=2TC2GJeoGd0fDE6CFodk1chsSkk0v0yO0wrFYim5t4g,27938
|
|
22
|
+
llm_trainer/utils.py,sha256=ox2fWtSOS7F2Nh7_FoHxuQgaps1jGW3q59VXz04wRuA,11491
|
|
23
|
+
project_llm_trainer-0.7.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.7.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
+
project_llm_trainer-0.7.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
+
project_llm_trainer-0.7.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.7.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.7.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.7.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
30
|
+
project_llm_trainer-0.7.0.dist-info/METADATA,sha256=Q_UU9xBZIIBFOmfQJg1708lFfYn4bu5FA0fuxJCCcxQ,195
|
|
31
|
+
project_llm_trainer-0.7.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.7.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.7.0.dist-info/RECORD,,
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
|
-
llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
|
|
3
|
-
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
|
-
llm_trainer/dpo_trainer.py,sha256=mETXpU1ZSasg1UM72wnh9NaoTuXBibuNuodfuW7u8Iw,12269
|
|
5
|
-
llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
|
|
6
|
-
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
|
-
llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
|
|
8
|
-
llm_trainer/grpo_trainer.py,sha256=-wbozslll_bcGUMqrbS0a73jhosyjc3oC3PHLSev6lw,16344
|
|
9
|
-
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
10
|
-
llm_trainer/loss.py,sha256=eYvOlCoguKnLvdGuqvQpGUoLVSADQ5coaU3DWYbJEdM,6811
|
|
11
|
-
llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
|
|
12
|
-
llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
|
|
13
|
-
llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
|
|
14
|
-
llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
|
|
15
|
-
llm_trainer/partition_utils.py,sha256=eEYNhfEIF4hGzZ3OLa6sEBIECz261drptEz_n7fZYtk,8396
|
|
16
|
-
llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
|
|
17
|
-
llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
|
|
18
|
-
llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
|
|
19
|
-
llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
|
|
20
|
-
llm_trainer/train_configs.py,sha256=U4hwXWKI6svDqiDOu6RPTitCzpxEYyjZUN6gwh_co8c,7510
|
|
21
|
-
llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
|
|
22
|
-
llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
|
|
23
|
-
project_llm_trainer-0.6.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
-
project_llm_trainer-0.6.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
-
project_llm_trainer-0.6.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
-
project_llm_trainer-0.6.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
-
project_llm_trainer-0.6.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
-
project_llm_trainer-0.6.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
-
project_llm_trainer-0.6.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
30
|
-
project_llm_trainer-0.6.0.dist-info/METADATA,sha256=_F0QQHrdQNGXG8eDGRDsgEvdX6fYWXSDg5Ad089CXHk,195
|
|
31
|
-
project_llm_trainer-0.6.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
-
project_llm_trainer-0.6.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
-
project_llm_trainer-0.6.0.dist-info/RECORD,,
|
{project_llm_trainer-0.6.0.data → project_llm_trainer-0.7.0.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|