project-llm-trainer 0.3 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
llm_trainer/tokenizer.py
ADDED
@@ -0,0 +1,166 @@
import os
import warnings
from typing import List, Dict, Union
from transformers import Qwen2TokenizerFast
from transformers import AddedToken
from transformers import LlamaTokenizer, LlamaTokenizerFast
import torch

TOKEN_TYPE_QWEN = 'qwen'
TOKEN_TYPE_ZH_LLAMA = "zh_llama"

AVAILABLE_TOKEN_TYPES = [TOKEN_TYPE_QWEN, TOKEN_TYPE_ZH_LLAMA]


class Tokenizer:
    def __init__(self, token_type: str = TOKEN_TYPE_ZH_LLAMA):
        super().__init__()
        assert token_type in AVAILABLE_TOKEN_TYPES, 'token type is unavailable'
        self.token_type = token_type

        # special-token surface forms used by the chat template
        self.text_end = '</s>'

        self.text_pad = '<pad>'
        self.text_unk = '<unk>'

        self.text_user = '<user>'
        self.text_assistant = '<assistant>'

        self.text_reasoning_start = '<reasoning>'
        self.text_reasoning_end = '</reasoning>'

        self.text_answer_start = '<answer>'
        self.text_answer_end = '</answer>'

        self.text_system = '<system>'

        self.text_image = '<image>'

        if token_type == TOKEN_TYPE_QWEN:
            self.tokenizer = Qwen2TokenizerFast(
                vocab_file=f"{os.environ['TOKEN_DIR']}qwen_vocab.json",
                merges_file=f"{os.environ['TOKEN_DIR']}qwen_merges.txt",
                unk_token=self.text_unk,
                eos_token=self.text_end,
                pad_token=self.text_pad
            )
            additional_special_tokens = [
                AddedToken(self.text_user, lstrip=False, rstrip=False),
                AddedToken(self.text_assistant, lstrip=False, rstrip=False),
                AddedToken(self.text_reasoning_start, lstrip=False, rstrip=False),
                AddedToken(self.text_reasoning_end, lstrip=False, rstrip=False),
                AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
                AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
                AddedToken(self.text_system, lstrip=False, rstrip=False),
                AddedToken(self.text_image, lstrip=False, rstrip=False),
            ]

            self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
        else:
            self.tokenizer = LlamaTokenizerFast.from_pretrained(os.environ['TOKEN_DIR'])
            # self.tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_DIR'])
            # self.tokenizer = PreTrainedTokenizerFast.from_pretrained(os.environ['TOKEN_DIR'], trust_remote_code=True)

        # resolve the ids of the special tokens defined above
        self.end = self.tokenizer.convert_tokens_to_ids(self.text_end)

        self.pad = self.tokenizer.convert_tokens_to_ids(self.text_pad)
        self.unk = self.tokenizer.convert_tokens_to_ids(self.text_unk)

        self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
        self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)

        self.reasoning_start = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_start)
        self.reasoning_end = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_end)

        self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
        self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)

        self.system = self.tokenizer.convert_tokens_to_ids(self.text_system)
        self.image = self.tokenizer.convert_tokens_to_ids(self.text_image)

        self.vocab_size = len(self.tokenizer)

    def encode(
        self,
        text: str,
        unsqueeze: bool = False,
        covert_tensor: bool = False
    ) -> Union[torch.Tensor, List[int]]:
        # [x,x,x]
        encoded = self.tokenizer.encode(text, add_special_tokens=False)

        # if self.token_type == TOKEN_TYPE_MISTRAL:
        #     # work around Mistral prepending a 29473 token to every sentence
        #     if encoded[0] == 29473:
        #         encoded = encoded[1:]

        if unsqueeze:
            # tensor: [[x,x,x]]
            return torch.tensor(encoded).long().unsqueeze(0)
        else:
            # tensor: [x,x,x]
            if covert_tensor:
                return torch.tensor(encoded).long()

            return encoded

    def decode(
        self,
        token: Union[torch.Tensor, List[int]],
        skip_special_tokens: bool = False
    ) -> str:
        return self.tokenizer.decode(token, skip_special_tokens=skip_special_tokens)

    def batch_decode(
        self,
        tokens: Union[torch.Tensor, List[int], List[List[int]]],
        skip_special_tokens: bool = False
    ) -> List[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=skip_special_tokens)

    def encode_to_token(self, text: str, unsqueeze=True, covert_tensor=True):
        warnings.warn('encode_to_token is deprecated. Please use `encode` instead.')
        return self.encode(text, unsqueeze, covert_tensor)

    def decode_to_text(self, token: torch.Tensor, skip_special_tokens: bool = False) -> str:
        warnings.warn('decode_to_text is deprecated. Please use `decode` instead.')
        return self.decode(token.squeeze(0), skip_special_tokens)

    def apply_chat_template(
        self,
        conversations: List[Dict[str, str]],
        tokenizer: bool = True,
        add_answer_tag_for_assistant: bool = True,
        unsqueeze=False,
        covert_tensor=False
    ):
        """
        [
            {"role":"system", "content":"system prompt"},
            {"role":"user", "content":"hello?"},
            {"role":"assistant", "content":"hello"},
            {"role":"user", "content":"hello hello?"},
            {"role":"assistant", "reasoning":"thinking", "content":"hello hello"},
        ]
        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><reasoning>thinking</reasoning><answer>hello hello</answer></s>
        """

        chat_template = ''
        support_roles = {'system': self.text_system, 'user': self.text_user, 'assistant': self.text_assistant}
        for conversation in conversations:
            role = conversation['role']
            if role in support_roles:
                content = conversation['content']
                if add_answer_tag_for_assistant and role == 'assistant':
                    content = f"{self.text_answer_start}{content}{self.text_answer_end}"

                if 'reasoning' in conversation:
                    content = f"{self.text_reasoning_start}{conversation['reasoning']}{self.text_reasoning_end}{content}"

                chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"

        if tokenizer:
            return self.encode(chat_template, unsqueeze, covert_tensor)

        return chat_template
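A minimal usage sketch of the Tokenizer API above, not part of the wheel itself: the TOKEN_DIR value is a hypothetical placeholder, and the sketch assumes that directory holds LlamaTokenizerFast files for the default 'zh_llama' type. Note that `covert_tensor` (apparently a typo of `convert_tensor`) is the parameter name as released, so callers must spell it that way.

# Hypothetical usage sketch (not shipped in the wheel).
import os
os.environ['TOKEN_DIR'] = '/path/to/tokenizer/'  # placeholder path, note trailing slash

from llm_trainer.tokenizer import Tokenizer

tok = Tokenizer(token_type='zh_llama')

# Plain round trip: encode returns a List[int] by default.
ids = tok.encode('hello world')
print(tok.decode(ids))

# Chat formatting: apply_chat_template flattens the conversation into
# '<system>...</s><user>...</s><assistant><answer>...</answer></s>'
# and, with tokenizer=True (the default), returns token ids.
conversation = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "hello?"},
    {"role": "assistant", "content": "hello"},
]
prompt_ids = tok.apply_chat_template(conversation, unsqueeze=True, covert_tensor=True)
print(prompt_ids.shape)  # torch.Size([1, seq_len])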
llm_trainer/tools.py
ADDED
@@ -0,0 +1,102 @@
import os
from abc import ABC, abstractmethod
import torch
from .tokenizer import Tokenizer
from .parallel_ds import DsParallel
from .parallel_fsdp import FsdpParallel
from .parallel_ddp import DdpParallel
from .parallel_none import NoneParallel
from .log import log


parallel_types = {
    'ds': DsParallel,
    'fsdp': FsdpParallel,
    'ddp': DdpParallel,
    'none': NoneParallel
}

dtypes = {
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64
}

class TrainerTools:
    def __init__(self):
        # run initialization only once for the singleton instance
        if not hasattr(TrainerTools, "_first_init"):
            TrainerTools._first_init = True

            self.parallel = self.new_parallel()

            self.tokenizer = Tokenizer(os.environ.get('TOKENIZERS_TYPE', 'zh_llama'))
            # AMP is used on CUDA unless DeepSpeed manages mixed precision itself
            self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)

            dtype = os.environ.get('DTYPE', None)
            self.dtype = dtypes[dtype] if dtype in dtypes else None

            if not self.dtype:
                self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

            log(f'world_size={self.parallel.world_size},'
                f' use_amp={self.use_amp},'
                f' dtype={self.dtype}')

    def new_parallel(self):
        parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
        log(f'parallel_type={parallel_type}')
        return parallel_types[parallel_type]()

    def __new__(cls, *args, **kwargs):
        # classic singleton: always return the shared instance
        if not hasattr(TrainerTools, "_instance"):
            TrainerTools._instance = object.__new__(cls)

        return TrainerTools._instance


class FileDataset(ABC):
    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def __getitem__(self, idx) -> str: ...


def estimate_data_size(
    file_dataset: FileDataset,
    max_position_embeddings: int,
    type: str
) -> int:
    """
    Estimate the total dataset size (number of samples).
    """
    data_size = 0
    files_count = len(file_dataset)

    if type == 'sft':
        from .dataset import LineByLineTextDataset
        for idx in range(files_count):
            dataset = LineByLineTextDataset(file_dataset[idx], max_position_embeddings)
            data_size += len(dataset)
    elif type == 'dpo':
        from .dataset import DPODataset
        for idx in range(files_count):
            dataset = DPODataset(file_dataset[idx], max_position_embeddings)
            data_size += len(dataset)
    elif type == 'grpo':
        from .dataset import GRPORolloutDataset
        for idx in range(files_count):
            dataset = GRPORolloutDataset(file_dataset[idx])
            data_size += len(dataset)
    else:
        from .dataset import TextDataset
        for idx in range(files_count):
            dataset = TextDataset(
                file_dataset[idx],
                max_position_embeddings,
                max_position_embeddings
            )
            data_size += len(dataset)

    return data_size
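A minimal usage sketch for TrainerTools and estimate_data_size, not part of the wheel: ListFileDataset and the data paths are hypothetical, and it assumes the same TOKEN_DIR setup as above plus a single-process run (PARALLEL_TYPE unset, so NoneParallel is used).

# Hypothetical usage sketch (not shipped in the wheel).
from llm_trainer.tools import TrainerTools, FileDataset, estimate_data_size

class ListFileDataset(FileDataset):
    """A FileDataset backed by an in-memory list of file paths."""
    def __init__(self, files):
        self._files = files

    def __len__(self) -> int:
        return len(self._files)

    def __getitem__(self, idx) -> str:
        return self._files[idx]

# TrainerTools is a singleton: repeated constructions return the same instance,
# and its tokenizer/parallel/dtype setup runs only on the first construction.
tools = TrainerTools()
assert tools is TrainerTools()

files = ListFileDataset(['data/part0.txt', 'data/part1.txt'])  # placeholder paths
n_samples = estimate_data_size(files, max_position_embeddings=1024, type='sft')
print(n_samples)

The singleton design lets any module call TrainerTools() for the shared tokenizer and parallel context without threading a handle through every trainer class.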