project_llm_trainer-0.3-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of project-llm-trainer might be problematic.

llm_trainer/tokenizer.py ADDED
@@ -0,0 +1,166 @@
+ import os
+ import warnings
+ from typing import List, Dict, Union
+ from transformers import Qwen2TokenizerFast
+ from transformers import AddedToken
+ from transformers import LlamaTokenizerFast
+ import torch
+
+ TOKEN_TYPE_QWEN = 'qwen'
+ TOKEN_TYPE_ZH_LLAMA = 'zh_llama'
+
+ AVAILABLE_TOKEN_TYPES = [TOKEN_TYPE_QWEN, TOKEN_TYPE_ZH_LLAMA]
+
+
+ class Tokenizer:
+     def __init__(self, token_type: str = TOKEN_TYPE_ZH_LLAMA):
+         super().__init__()
+         assert token_type in AVAILABLE_TOKEN_TYPES, f'token_type must be one of {AVAILABLE_TOKEN_TYPES}'
+         self.token_type = token_type
+
+         self.text_end = '</s>'
+
+         self.text_pad = '<pad>'
+         self.text_unk = '<unk>'
+
+         self.text_user = '<user>'
+         self.text_assistant = '<assistant>'
+
+         self.text_reasoning_start = '<reasoning>'
+         self.text_reasoning_end = '</reasoning>'
+
+         self.text_answer_start = '<answer>'
+         self.text_answer_end = '</answer>'
+
+         self.text_system = '<system>'
+
+         self.text_image = '<image>'
+
+         if token_type == TOKEN_TYPE_QWEN:
+             self.tokenizer = Qwen2TokenizerFast(
+                 vocab_file=os.path.join(os.environ['TOKEN_DIR'], 'qwen_vocab.json'),
+                 merges_file=os.path.join(os.environ['TOKEN_DIR'], 'qwen_merges.txt'),
+                 unk_token=self.text_unk,
+                 eos_token=self.text_end,
+                 pad_token=self.text_pad
+             )
+             additional_special_tokens = [
+                 AddedToken(self.text_user, lstrip=False, rstrip=False),
+                 AddedToken(self.text_assistant, lstrip=False, rstrip=False),
+                 AddedToken(self.text_reasoning_start, lstrip=False, rstrip=False),
+                 AddedToken(self.text_reasoning_end, lstrip=False, rstrip=False),
+                 AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
+                 AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
+                 AddedToken(self.text_system, lstrip=False, rstrip=False),
+                 AddedToken(self.text_image, lstrip=False, rstrip=False),
+             ]
+
+             self.tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
+         else:
+             self.tokenizer = LlamaTokenizerFast.from_pretrained(os.environ['TOKEN_DIR'])
+             # self.tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_DIR'])
+             # self.tokenizer = PreTrainedTokenizerFast.from_pretrained(os.environ['TOKEN_DIR'], trust_remote_code=True)
+
+         self.end = self.tokenizer.convert_tokens_to_ids(self.text_end)
+
+         self.pad = self.tokenizer.convert_tokens_to_ids(self.text_pad)
+         self.unk = self.tokenizer.convert_tokens_to_ids(self.text_unk)
+
+         self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
+         self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
+
+         self.reasoning_start = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_start)
+         self.reasoning_end = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_end)
+
+         self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
+         self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
+
+         self.system = self.tokenizer.convert_tokens_to_ids(self.text_system)
+         self.image = self.tokenizer.convert_tokens_to_ids(self.text_image)
+
+         self.vocab_size = len(self.tokenizer)
+
+     def encode(
+         self,
+         text: str,
+         unsqueeze: bool = False,
+         convert_tensor: bool = False
+     ) -> Union[torch.Tensor, List[int]]:
+         # [x, x, x]
+         encoded = self.tokenizer.encode(text, add_special_tokens=False)
+
+         # if self.token_type == TOKEN_TYPE_MISTRAL:
+         #     # Work around Mistral prepending token 29473 to every sentence
+         #     if encoded[0] == 29473:
+         #         encoded = encoded[1:]
+
+         if unsqueeze:
+             # tensor: [[x, x, x]]
+             return torch.tensor(encoded).long().unsqueeze(0)
+         else:
+             # tensor: [x, x, x]
+             if convert_tensor:
+                 return torch.tensor(encoded).long()
+
+             return encoded
+
+     def decode(
+         self,
+         token: Union[torch.Tensor, List[int]],
+         skip_special_tokens: bool = False
+     ) -> str:
+         return self.tokenizer.decode(token, skip_special_tokens=skip_special_tokens)
+
+     def batch_decode(
+         self,
+         tokens: Union[torch.Tensor, List[int], List[List[int]]],
+         skip_special_tokens: bool = False
+     ) -> List[str]:
+         return self.tokenizer.batch_decode(tokens, skip_special_tokens=skip_special_tokens)
+
+     def encode_to_token(self, text: str, unsqueeze=True, convert_tensor=True):
+         warnings.warn('encode_to_token is deprecated. Please use `encode` instead.', DeprecationWarning)
+         return self.encode(text, unsqueeze, convert_tensor)
+
+     def decode_to_text(self, token: torch.Tensor, skip_special_tokens: bool = False) -> str:
+         warnings.warn('decode_to_text is deprecated. Please use `decode` instead.', DeprecationWarning)
+         return self.decode(token.squeeze(0), skip_special_tokens)
+
+     def apply_chat_template(
+         self,
+         conversations: List[Dict[str, str]],
+         tokenize: bool = True,
+         add_answer_tag_for_assistant: bool = True,
+         unsqueeze: bool = False,
+         convert_tensor: bool = False
+     ):
+         """
+         [
+             {"role": "system", "content": "system prompt"},
+             {"role": "user", "content": "hello?"},
+             {"role": "assistant", "content": "hello"},
+             {"role": "user", "content": "hello hello?"},
+             {"role": "assistant", "reasoning": "thinking", "content": "hello hello"},
+         ]
+         becomes
+         <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><reasoning>thinking</reasoning><answer>hello hello</answer></s>
+         """
+
+         chat_template = ''
+         support_roles = {'system': self.text_system, 'user': self.text_user, 'assistant': self.text_assistant}
+         for conversation in conversations:
+             role = conversation['role']
+             if role in support_roles:
+                 content = conversation['content']
+                 if add_answer_tag_for_assistant and role == 'assistant':
+                     content = f"{self.text_answer_start}{content}{self.text_answer_end}"
+
+                 if 'reasoning' in conversation:
+                     content = f"{self.text_reasoning_start}{conversation['reasoning']}{self.text_reasoning_end}{content}"
+
+                 chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
+
+         if tokenize:
+             return self.encode(chat_template, unsqueeze, convert_tensor)
+
+         return chat_template
+
llm_trainer/tools.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ from abc import ABC, abstractmethod
+ import torch
+ from .tokenizer import Tokenizer
+ from .parallel_ds import DsParallel
+ from .parallel_fsdp import FsdpParallel
+ from .parallel_ddp import DdpParallel
+ from .parallel_none import NoneParallel
+ from .log import log
+
+
+ parallel_types = {
+     'ds': DsParallel,
+     'fsdp': FsdpParallel,
+     'ddp': DdpParallel,
+     'none': NoneParallel
+ }
+
+ dtypes = {
+     'float': torch.float,
+     'float16': torch.float16,
+     'float32': torch.float32,
+     'float64': torch.float64
+ }
+
+
+ class TrainerTools:
+     def __init__(self):
+         if not hasattr(TrainerTools, '_first_init'):  # singleton: run the init body only once
+             TrainerTools._first_init = True
+
+             self.parallel = self.new_parallel()
+
+             self.tokenizer = Tokenizer(os.environ.get('TOKENIZERS_TYPE', 'zh_llama'))
+             self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
+
+             dtype = os.environ.get('DTYPE', None)
+             self.dtype = dtypes.get(dtype)
+
+             if not self.dtype:
+                 self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+
+             log(f'world_size={self.parallel.world_size},'
+                 f' use_amp={self.use_amp},'
+                 f' dtype={self.dtype}')
+
+     def new_parallel(self):
+         parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
+         log(f'parallel_type={parallel_type}')
+         return parallel_types[parallel_type]()
+
+     def __new__(cls, *args, **kwargs):
+         if not hasattr(TrainerTools, '_instance'):  # create the shared instance on first use
+             TrainerTools._instance = object.__new__(cls)
+
+         return TrainerTools._instance
+
+
+ class FileDataset(ABC):
+     @abstractmethod
+     def __len__(self) -> int: ...
+
+     @abstractmethod
+     def __getitem__(self, idx) -> str: ...
+
+
+ def estimate_data_size(
+     file_dataset: FileDataset,
+     max_position_embeddings: int,
+     type: str
+ ) -> int:
+     """
+     Estimate the total number of samples across all files in the dataset.
+     """
+     data_size = 0
+     files_count = len(file_dataset)
+
+     if type == 'sft':
+         from .dataset import LineByLineTextDataset
+         for idx in range(files_count):
+             dataset = LineByLineTextDataset(file_dataset[idx], max_position_embeddings)
+             data_size += len(dataset)
+     elif type == 'dpo':
+         from .dataset import DPODataset
+         for idx in range(files_count):
+             dataset = DPODataset(file_dataset[idx], max_position_embeddings)
+             data_size += len(dataset)
+     elif type == 'grpo':
+         from .dataset import GRPORolloutDataset
+         for idx in range(files_count):
+             dataset = GRPORolloutDataset(file_dataset[idx])
+             data_size += len(dataset)
+     else:
+         from .dataset import TextDataset
+         for idx in range(files_count):
+             dataset = TextDataset(
+                 file_dataset[idx],
+                 max_position_embeddings,
+                 max_position_embeddings
+             )
+             data_size += len(dataset)
+
+     return data_size
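And a sketch of how tools.py is meant to be driven. Again, this is not part of the release: the FileDataset subclass, file paths, and environment values are hypothetical, and only environment variables read by the code above are used:

    import os
    os.environ['TOKEN_DIR'] = '/path/to/tokenizer_dir'  # TrainerTools builds a Tokenizer, so this is needed too
    os.environ['PARALLEL_TYPE'] = 'none'                # one of: 'ds', 'fsdp', 'ddp', 'none'
    os.environ['DTYPE'] = 'float16'                     # optional; otherwise bf16/fp16 is autodetected

    from llm_trainer.tools import TrainerTools, FileDataset, estimate_data_size

    class ListFileDataset(FileDataset):  # hypothetical minimal implementation
        def __init__(self, paths):
            self._paths = paths

        def __len__(self) -> int:
            return len(self._paths)

        def __getitem__(self, idx) -> str:
            return self._paths[idx]

    tools = TrainerTools()
    assert tools is TrainerTools()  # __new__ always returns the shared singleton

    # Count samples without building the full training pipeline
    n = estimate_data_size(
        ListFileDataset(['data/sft_00.jsonl']),  # hypothetical file
        max_position_embeddings=1024,
        type='sft',  # 'sft' | 'dpo' | 'grpo'; anything else falls back to TextDataset
    )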