project-llm-trainer 0.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +683 -0
- llm_trainer/checkpoint.py +126 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +297 -0
- llm_trainer/ds_checkpoint.py +63 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +450 -0
- llm_trainer/grpo_trainer.py +385 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +268 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +521 -0
- llm_trainer/scheduler.py +179 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +324 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +547 -0
- project_llm_trainer-0.12.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.12.3.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.12.3.data/scripts/ds_train +17 -0
- project_llm_trainer-0.12.3.data/scripts/plot_log +69 -0
- project_llm_trainer-0.12.3.data/scripts/plot_lr +45 -0
- project_llm_trainer-0.12.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.12.3.data/scripts/smart_train +37 -0
- project_llm_trainer-0.12.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.12.3.dist-info/RECORD +32 -0
- project_llm_trainer-0.12.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.12.3.dist-info/top_level.txt +1 -0
llm_trainer/sft_trainer.py
ADDED
@@ -0,0 +1,97 @@
+from typing import Optional, Tuple, List
+from torch.utils.data import Dataset
+
+from llm_model import (
+    VLMConfig,
+    LlmModel,
+    VlmModel
+)
+
+from .base_trainer import BaseTrainer
+from .train_configs import TrainConfig
+from .dataset import SFTDataset
+from .utils import get_sft_collate_fn
+from .tools import TrainerTools
+
+
+class SFTTrainer(BaseTrainer):
+    def __init__(
+        self,
+        *,
+        train_config: TrainConfig,
+        eval_prompts: List[str],
+        eval_image_tags: Optional[List[str]] = None
+    ):
+        self.sft_config = train_config.sft_config
+        self.pixel_values_provider = self.sft_config.pixel_values_provider
+        self.eval_image_tags = eval_image_tags
+
+        super().__init__(
+            train_config=train_config,
+            eval_prompts=eval_prompts,
+            kd_config=self.sft_config.kd_config,
+            gradient_accumulation_steps=self.sft_config.gradient_accumulation_steps
+        )
+
+        if isinstance(train_config.model_config, VLMConfig):
+            self.pixel_values_provider = self.sft_config.pixel_values_provider
+        else:
+            self.pixel_values_provider = None
+
+    def _new_model(self, train_config: TrainConfig):
+        if isinstance(train_config.model_config, VLMConfig):
+            return VlmModel(train_config.model_config)
+        else:
+            return LlmModel(train_config.model_config)
+
+    def _check_freeze_llm_model(self, model):
+        # freeze llm model for vlm training
+        if self.sft_config.freeze_llm_model:
+            for name, param in model.named_parameters():
+                if not any(sub_module in name for sub_module in ['multi_modal_projector']):
+                    param.requires_grad = False
+
+            # model.embed_tokens.eval()
+            # model.layers.eval()
+            # model.head_norm.eval()
+            # model.lm_head.eval()
+
+    def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+        sft_collate_fn = get_sft_collate_fn(self.sft_config.mask_prompt)
+        parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+        data_loader_kwargs.update({"collate_fn": sft_collate_fn})
+
+        return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+    def _get_pixel_values(self, batch_data):
+        if self.pixel_values_provider and 'image_tags' in batch_data:
+            image_tags = batch_data['image_tags']
+            return self.pixel_values_provider(image_tags).to(TrainerTools().parallel.device)
+
+        return None
+
+    def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
+        if not self.eval_image_tags:
+            return None, None
+
+        eval_image_tag = self.eval_image_tags[eval_idx]
+        if isinstance(self.train_config.model_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
+            return self.pixel_values_provider([eval_image_tag]), self.train_config.model_config.tokens_per_image
+
+        return None, None
+
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
+        max_seq_len = self.train_config.max_seq_len
+
+        image_tag_file_path = None
+        tokens_per_image = -1
+
+        if isinstance(self.train_config.model_config, VLMConfig):
+            if self.sft_config.image_tags_file_dataset:
+                image_tag_file_path = self.sft_config.image_tags_file_dataset[file_idx]
+
+            if self.train_config.model_config.tokens_per_image:
+                tokens_per_image = self.train_config.model_config.tokens_per_image
+
+        return SFTDataset(file_path, max_seq_len, image_tag_file_path, tokens_per_image), file_path
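
For orientation (not part of the package diff): the VLM hooks above expect SFTConfig.pixel_values_provider to map a batch of image tags to a pixel tensor, which _get_pixel_values then moves to the training device. A minimal sketch follows, assuming image features were pre-extracted to .pt files under a hypothetical features/ directory; only the SFTConfig fields come from this release.

from typing import List
import torch
from llm_trainer.train_configs import SFTConfig

def load_pixel_values(image_tags: List[str]) -> torch.Tensor:
    # hypothetical provider: stack one pre-extracted tensor per image tag into a batch
    return torch.stack([torch.load(f"features/{tag}.pt") for tag in image_tags])

sft_config = SFTConfig(
    mask_prompt=True,                         # mask the prompt-part tokens
    pixel_values_provider=load_pixel_values,  # called by _get_pixel_values above
    freeze_llm_model=True,                    # leaves only multi_modal_projector trainable
)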
llm_trainer/tokenizer.py
ADDED
@@ -0,0 +1,162 @@
+import os
+import warnings
+from typing import List, Dict, Union
+from transformers import AutoTokenizer
+import torch
+
+
+class Tokenizer:
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_DIR'])
+
+        self.text_end = '</s>'
+
+        self.text_pad = '<pad>'
+        self.text_unk = '<unk>'
+
+        self.text_user = '<user>'
+        self.text_assistant = '<assistant>'
+
+        self.text_think_start = '<think>'
+        self.text_think_end = '</think>'
+
+        self.text_answer_start = '<answer>'
+        self.text_answer_end = '</answer>'
+
+        self.text_system = '<system>'
+
+        self.text_image = '<image>'
+
+        self.end = self.tokenizer.convert_tokens_to_ids(self.text_end)
+
+        self.pad = self.tokenizer.convert_tokens_to_ids(self.text_pad)
+        self.unk = self.tokenizer.convert_tokens_to_ids(self.text_unk)
+
+        self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
+        self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
+
+        self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
+        self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)
+
+        self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
+        self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
+
+        self.system = self.tokenizer.convert_tokens_to_ids(self.text_system)
+        self.image = self.tokenizer.convert_tokens_to_ids(self.text_image)
+
+        self.vocab_size = len(self.tokenizer)
+
+    def encode(
+        self,
+        text: str,
+        unsqueeze: bool = False,
+        covert_tensor: bool = False
+    ) -> Union[torch.Tensor, List[int]]:
+        # [x,x,x]
+        encoded = self.tokenizer.encode(text, add_special_tokens=False)
+
+        if unsqueeze:
+            # tensor: [[x,x,x]]
+            return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
+        else:
+            # tensor: [x,x,x]
+            if covert_tensor:
+                return torch.tensor(encoded, dtype=torch.long)
+
+            return encoded
+
+    def batch_encode(
+        self,
+        text: List[str],
+        padding=False,
+        truncation=False,
+        covert_tensor: bool = False,
+        return_attention_mask: bool = False
+    ) -> Union[torch.Tensor, List[List[int]]]:
+        encoded = self.tokenizer(
+            text,
+            padding=padding,
+            truncation=truncation,
+            return_attention_mask=return_attention_mask
+        )['input_ids']
+
+        if covert_tensor:
+            encoded = torch.tensor(encoded, dtype=torch.long)
+
+        return encoded
+
+    def decode(
+        self,
+        token: Union[torch.Tensor, List[int]],
+        skip_special_tokens: bool = False
+    ) -> str:
+        return self.tokenizer.decode(token, skip_special_tokens=skip_special_tokens)
+
+    def batch_decode(
+        self,
+        tokens: Union[torch.Tensor, List[int], List[List[int]]],
+        skip_special_tokens: bool = False
+    ) -> List[str]:
+        return self.tokenizer.batch_decode(tokens, skip_special_tokens=skip_special_tokens)
+
+    def encode_to_token(self, text: str, unsqueeze=True, covert_tensor=True):
+        warnings.warn('encode_to_token is deprecated. Please use `encode` instead.')
+        return self.encode(text, unsqueeze, covert_tensor)
+
+    def decode_to_text(self, token: torch.Tensor, skip_special_tokens: bool = False) -> str:
+        warnings.warn('decode_to_text is deprecated. Please use `decode` instead.')
+        return self.decode(token.squeeze(0), skip_special_tokens)
+
+    def apply_chat_template(
+        self,
+        conversations: List[Dict[str, str]],
+        tokenizer: bool = True,
+        add_answer_tag_for_assistant: bool = True,
+        unsqueeze=False,
+        covert_tensor=False
+    ):
+        """
+        [
+            {"role":"system", "content":"system prompt"},
+            {"role":"user", "content":"hello?"},
+            {"role":"assistant", "content":"hello"},
+            {"role":"user", "content":"hello hello?"},
+            {"role":"assistant", "think":"thinking", "content":"hello hello"},
+        ]
+        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
+        """
+
+        chat_template = ''
+        support_roles = {'system': self.text_system, 'user': self.text_user, 'assistant': self.text_assistant}
+        for conversation in conversations:
+            role = conversation['role']
+            if role in support_roles:
+                content = conversation['content']
+                if add_answer_tag_for_assistant and role == 'assistant':
+                    content = f"{self.text_answer_start}{content}{self.text_answer_end}"
+
+                if 'think' in conversation:
+                    content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"
+
+                chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
+
+        if tokenizer:
+            return self.encode(chat_template, unsqueeze, covert_tensor)
+
+        return chat_template
+
+    def get_special_tokens_dict(self):
+        return {
+            self.text_end: self.end,
+            self.text_pad: self.pad,
+            self.text_unk: self.unk,
+            self.text_user: self.user,
+            self.text_assistant: self.assistant,
+            self.text_think_start: self.think_start,
+            self.text_think_end: self.think_end,
+            self.text_answer_start: self.answer_start,
+            self.text_answer_end: self.answer_end,
+            self.text_system: self.system,
+            self.text_image: self.image,
+        }
+
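
A usage sketch of the chat template documented above (not part of the diff); it assumes the TOKEN_DIR environment variable points at a tokenizer directory that already defines these special tokens, and the path below is a placeholder.

import os
os.environ.setdefault('TOKEN_DIR', '/path/to/tokenizer')   # hypothetical path

from llm_trainer.tokenizer import Tokenizer

tok = Tokenizer()
conversations = [
    {"role": "system", "content": "system prompt"},
    {"role": "user", "content": "hello?"},
    {"role": "assistant", "content": "hello"},
]

# tokenizer=False returns the raw template string, per the docstring above:
# <system>system prompt</s><user>hello?</s><assistant><answer>hello</answer></s>
text = tok.apply_chat_template(conversations, tokenizer=False)

# the default (tokenizer=True) encodes it; unsqueeze/covert_tensor control the returned shape
ids = tok.apply_chat_template(conversations, unsqueeze=True, covert_tensor=True)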
llm_trainer/tools.py
ADDED
@@ -0,0 +1,116 @@
+import os
+from abc import ABC, abstractmethod
+import torch
+from .tokenizer import Tokenizer
+from .parallel import DsParallel, DdpParallel, NoneParallel
+from .log import Logger
+
+
+parallel_types = {
+    'ds': DsParallel,
+    'ddp': DdpParallel,
+    'none': NoneParallel
+}
+
+dtypes = {
+    'float': torch.float,
+    'float16': torch.float16,
+    'float32': torch.float32,
+    'float64': torch.float64
+}
+
+class TrainerTools:
+    def __init__(self):
+        if not hasattr(TrainerTools, "_first_init"):
+            TrainerTools._first_init = True
+
+            self.parallel = self._new_parallel()
+
+            self.tokenizer = Tokenizer()
+            self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
+
+            Logger.std_log(f'word_size={self.parallel.world_size}, use_amp={self.use_amp}')
+
+    def _new_parallel(self):
+        parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
+        Logger.std_log(f'parallel_type={parallel_type}')
+        return parallel_types[parallel_type]()
+
+    def __new__(cls, *args, **kwargs):
+        if not hasattr(TrainerTools, "_instance"):
+            TrainerTools._instance = object.__new__(cls)
+
+        return TrainerTools._instance
+
+
+class FileDataset(ABC):
+    @abstractmethod
+    def __len__(self) -> int: ...
+
+    @abstractmethod
+    def __getitem__(self, idx) -> str: ...
+
+
+def estimate_data_size(
+    file_dataset: FileDataset,
+    max_seq_len: int,
+    type: str
+) -> int:
+    """
+    Estimate the total number of samples across the dataset files.
+    """
+    data_size = 0
+    files_count = len(file_dataset)
+
+    if type == 'sft':
+        from .dataset import SFTDataset
+        for idx in range(files_count):
+            dataset = SFTDataset(file_dataset[idx], max_seq_len)
+            data_size += len(dataset)
+    elif type == 'dpo':
+        from .dataset import DPODataset
+        for idx in range(files_count):
+            dataset = DPODataset(file_dataset[idx], max_seq_len)
+            data_size += len(dataset)
+    elif type == 'grpo' or type == 'ppo':
+        from .dataset import RLDataset
+        for idx in range(files_count):
+            dataset = RLDataset(file_dataset[idx])
+            data_size += len(dataset)
+    else:
+        from .dataset import PretrainDataset
+        for idx in range(files_count):
+            dataset = PretrainDataset(
+                file_dataset[idx],
+                max_seq_len,
+                max_seq_len
+            )
+            data_size += len(dataset)
+
+    return data_size
+
+
+def extract_policy_weights_from_ppo(model_config, ppo_weights):
+    from llm_model import LlmModel
+    from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel
+
+    policy_model = LlmModel(model_config)
+    value_model = ValueModel(LlmModel(model_config))
+
+    wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
+    wrapper.load_state_dict(ppo_weights)
+
+    return wrapper.policy_model.state_dict()
+
+
+def extract_value_weights_from_ppo(model_config, ppo_weights):
+    from llm_model import LlmModel
+    from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel
+
+    policy_model = LlmModel(model_config)
+    value_model = ValueModel(LlmModel(model_config))
+
+    wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
+    wrapper.load_state_dict(ppo_weights)
+
+    return wrapper.value_model.state_dict()
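
A sketch of the FileDataset contract above (not part of the diff): an implementation only has to map an index to a file path, and estimate_data_size then opens each file with the matching dataset class and sums the sample counts. JsonlFileDataset and the file paths are hypothetical.

from typing import List
from llm_trainer.tools import FileDataset, estimate_data_size

class JsonlFileDataset(FileDataset):
    """Hypothetical FileDataset backed by an in-memory list of file paths."""
    def __init__(self, paths: List[str]):
        self.paths = paths

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx) -> str:
        return self.paths[idx]

files = JsonlFileDataset(["data/sft_00.jsonl", "data/sft_01.jsonl"])   # hypothetical files
# walks every file with SFTDataset and sums the per-file sample counts
total_samples = estimate_data_size(files, max_seq_len=1024, type='sft')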
llm_trainer/train_configs.py
ADDED
@@ -0,0 +1,324 @@
+from typing import Optional, Union, Callable, List, Mapping, Any, Tuple
+from dataclasses import dataclass, field
+
+import torch
+from llm_model import ModelConfig, VLMConfig
+from .tools import FileDataset
+
+
+@dataclass(kw_only=True)
+class DsOffloadConfig:
+    device: str = 'cpu'
+    pin_memory: bool = True
+
+
+@dataclass(kw_only=True)
+class DsActivationCheckpointingConfig:
+    partition_activations: bool = True
+    cpu_checkpointing: bool = False
+    contiguous_memory_optimization: bool = True
+    number_checkpoints: Optional[int] = None
+    synchronize_checkpoint_boundary: bool = False
+    profile: bool = False
+
+
+@dataclass(kw_only=True)
+class DsZeROConfig:
+    stage: int
+    allgather_partitions: Optional[bool] = True
+    allgather_bucket_size: Optional[int] = 5e8
+    overlap_comm: Optional[bool] = True
+    reduce_scatter: Optional[bool] = True
+    reduce_bucket_size: Optional[Union[str, int]] = 5e8
+    contiguous_gradients: Optional[bool] = True
+
+@dataclass(kw_only=True)
+class DsZero0Config(DsZeROConfig):
+    stage: int = field(default=0, init=False)
+
+
+@dataclass(kw_only=True)
+class DsZero1Config(DsZeROConfig):
+    stage: int = field(default=1, init=False)
+
+
+@dataclass(kw_only=True)
+class DsZero2Config(DsZeROConfig):
+    stage: int = field(default=2, init=False)
+    offload_optimizer: Optional[DsOffloadConfig] = None
+    offload_param: Optional[DsOffloadConfig] = None
+
+
+@dataclass(kw_only=True)
+class DsZero3Config(DsZeROConfig):
+    stage: int = field(default=3, init=False)
+    sub_group_size: Optional[int] = 1e9
+    stage3_prefetch_bucket_size: Optional[Union[str, int]] = 'auto'
+    stage3_param_persistence_threshold: Optional[Union[str, int]] = 'auto'
+    stage3_max_live_parameters: Optional[int] = 1e9
+    stage3_max_reuse_distance: Optional[int] = 1e9
+    stage3_gather_16bit_weights_on_model_save: Optional[bool] = True
+    offload_optimizer: Optional[DsOffloadConfig] = None
+    offload_param: Optional[DsOffloadConfig] = None
+
+
+@dataclass(kw_only=True)
+class DsFp16Config:
+    enabled: Union[str, bool] = 'auto'
+    loss_scale: int = 0
+    loss_scale_window: int = 1000
+    initial_scale_power: int = 16
+    hysteresis: int = 2
+    min_loss_scale: int = 1
+    fp16_opt_level: Optional[str] = 'O2'
+
+
+@dataclass(kw_only=True)
+class DsBf16Config:
+    enabled: bool = True
+
+
+@dataclass(kw_only=True)
+class DsConfig:
+    zero_config: Optional[DsZeROConfig] = field(default_factory=DsZero3Config)
+    fp16_config: Optional[DsFp16Config] = field(default_factory=DsFp16Config)
+    bf16_config: Optional[DsBf16Config] = field(default_factory=DsBf16Config)
+    gradient_clipping: Optional[float] = 1.0
+    activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
+
+
+@dataclass(kw_only=True)
+class DataLoaderConfig:
+    """
+    Data loader configuration.
+    Args:
+        data_loader_pin_memory (`bool`, *optional*, defaults to False):
+            pin_memory setting for the DataLoader
+        data_loader_num_workers (`int`, *optional*, defaults to 0):
+            num_workers setting for the DataLoader
+        data_loader_shuffle (`bool`, *optional*, defaults to False):
+            whether to shuffle the data
+        data_loader_drop_last (`bool`, defaults to True):
+            whether to drop the last batch when it is smaller than batch_size
+    """
+    data_loader_pin_memory: bool = False
+    data_loader_num_workers: int = 0
+    data_loader_shuffle: bool = False
+    data_loader_drop_last: bool = True
+
+
+@dataclass(kw_only=True)
+class OptimConfig:
+    optim_type: str = 'adam'  # or 'lion'
+    enable_lr_scheduler: bool = False
+    initial_lr: float
+    weight_decay: Optional[float] = None
+    betas: Optional[Tuple[float, float]] = None
+    warmup_iters: Optional[int] = None
+    max_lr: Optional[float] = None
+    min_lr: Optional[float] = None
+    cosine_annealing_period: Optional[int] = None
+    cosine_annealing_period_mul: int = 0
+
+
+@dataclass(kw_only=True)
+class LossConfig:
+    critical_tokens: Optional[List[int]] = None
+    critical_alpha: float = 1.0
+    aux_loss_coef: Optional[float] = 0.001
+
+
+@dataclass(kw_only=True)
+class KDConfig:
+    """
+    Knowledge-distillation configuration.
+
+    Args:
+        teacher_logits_provider (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+            provides the teacher model's logits for knowledge distillation
+        kd_coef (`float`, *optional*, defaults to 0.4):
+            weight of the distillation loss: loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss
+    """
+    teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+    kd_coef: float = 0.4
+
+
+@dataclass(kw_only=True)
+class EvalConfig:
+    """
+    Evaluation configuration.
+
+    Args:
+        eval_batch_interval (`int`, defaults to 100):
+            run a model eval every this many batches
+    """
+    max_new_tokens: int
+    eval_batch_interval: int = 100
+    temperature: float = 1.0
+    top_p: float = 0.95
+    top_k: Optional[float] = None
+
+
+@dataclass(kw_only=True)
+class PretrainConfig:
+    """
+    Pre-training configuration.
+
+    Args:
+        gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+            number of gradient-accumulation steps; set to 0 to disable gradient accumulation.
+            Currently only used for pretrain/sft/dpo, not for ppo/grpo/gspo.
+        kd_config (`KDConfig`, *optional*, defaults to None):
+            knowledge-distillation settings; None disables distillation
+    """
+    gradient_accumulation_steps: int = 1
+    kd_config: Optional[KDConfig] = None
+
+
+@dataclass(kw_only=True)
+class SFTConfig:
+    """
+    SFT training configuration.
+
+    Args:
+        mask_prompt (`bool`):
+            whether to mask the prompt-part tokens
+        gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+            number of gradient-accumulation steps; set to 0 to disable gradient accumulation.
+            Currently only used for pretrain/sft/dpo, not for ppo/grpo/gspo.
+        kd_config (`KDConfig`, *optional*, defaults to None):
+            knowledge-distillation settings; None disables distillation
+        pixel_values_provider (`Callable[[list[str]], torch.Tensor]`, *optional*, defaults to None):
+            provides pixel_values for the given image tags when training a VLM
+        freeze_llm_model:
+            whether to freeze the LLM parameters when training a VLM
+    """
+    mask_prompt: bool = True
+    gradient_accumulation_steps: int = 1
+    kd_config: Optional[KDConfig] = None
+    image_tags_file_dataset: Optional[FileDataset] = None
+    pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None
+    freeze_llm_model: bool = False
+
+
+@dataclass(kw_only=True)
+class DPOConfig:
+    """
+    DPO training configuration.
+
+    Args:
+        mask_prompt (`bool`):
+            whether to mask the prompt-part tokens
+        gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+            number of gradient-accumulation steps; set to 0 to disable gradient accumulation.
+            Currently only used for pretrain/sft/dpo, not for ppo/grpo/gspo.
+    """
+    ref_model_checkpoint: Mapping[str, Any]
+    mask_prompt: bool = True
+    gradient_accumulation_steps: int = 1
+    loss_beta: float
+    loss_label_smoothing: float = 0.0
+    loss_ipo: bool = False
+    nll_loss_coef: Optional[float] = None
+
+
+@dataclass(kw_only=True)
+class PPOConfig:
+    ppo_epochs: int
+    ppo_batch_size: int
+    ref_model_checkpoint: Mapping[str, Any]
+    value_model_checkpoint: Optional[Mapping[str, Any]] = None
+    gradient_accumulation_steps: int = 1
+    gamma: float = 1.0
+    lam: float = 0.95
+    clip_eps: float = 0.1
+    vf_coef: float = 0.5
+    kl_beta: float = 0.02
+    kl_estimator: str = 'k1'  # or k3
+    missing_eos_penalty: Optional[float] = None
+    normalize_rewards: bool = False
+    whiten_rewards: bool = False
+    gen_max_new_tokens: int
+    gen_temperature: Optional[float] = None
+    gen_k: Optional[int] = None
+    gen_p: Optional[float] = None
+    gen_suppress_tokens: Optional[list[int]] = None
+
+
+@dataclass(kw_only=True)
+class GRPOConfig:
+    grpo_steps: int = 1
+    group_size: int = 12
+    mixup_alpha: float = 1.0
+    loss_beta: float = 0.0  # or 0.04 for grpo
+    loss_clip_eps: float = 3e-4
+    loss_clip_eps_high: Optional[float] = 4e-4
+    loss_delta: Optional[float] = None
+    loss_importance_sampling_level: str = 'seq'  # token or seq
+    loss_type: str = 'grpo'  # grpo or bnpo or dr_grpo
+    gen_max_new_tokens: int
+    gen_temperature: Optional[float] = None
+    gen_k: Optional[int] = None
+    gen_p: Optional[float] = None
+    gen_suppress_tokens: Optional[list[int]] = None
+
+
+@dataclass(kw_only=True)
+class TrainConfig:
+    """
+    Training configuration.
+
+    Args:
+        n_epochs (`int`):
+            number of training epochs
+        batch_size (`int`):
+            size of each batch
+        model_config (`ModelConfig`):
+            model configuration
+        init_state_dict:
+            initial checkpoint (state dict)
+        file_dataset (`FileDataset`):
+            dataset of training files
+        max_seq_len (`int`, defaults to None):
+            maximum training sequence length; if None, the model's max_position_embedding is used
+        data_loader_config (`DataLoaderConfig`):
+            data loader settings
+        loss_config:
+            loss settings
+        ds_config:
+            DeepSpeed settings
+        eval_config:
+            eval settings
+        optim_config (`OptimConfig`):
+            optimizer settings
+        pretrain_config:
+            pre-training settings; only used with Trainer
+        sft_config:
+            SFT settings; only used with SFTTrainer
+        dpo_config:
+            DPO settings; only used with DPOTrainer
+        ppo_config:
+            PPO settings; only used with PPOTrainer
+        grpo_config:
+            GRPO settings; only used with GRPOTrainer
+    """
+    n_epochs: int
+    batch_size: int
+    model_config: Union[ModelConfig, VLMConfig]
+    init_state_dict: Optional[Mapping[str, Any]] = None
+
+    file_dataset: FileDataset
+    max_seq_len: int
+    data_loader_config: DataLoaderConfig = field(default_factory=DataLoaderConfig)
+
+    loss_config: LossConfig = field(default_factory=LossConfig)
+    optim_config: OptimConfig = field(default_factory=OptimConfig)
+    ds_config: DsConfig = field(default_factory=DsConfig)
+
+    eval_config: EvalConfig = field(default_factory=EvalConfig)
+
+    pretrain_config: Optional[PretrainConfig] = None
+    sft_config: Optional[SFTConfig] = None
+    dpo_config: Optional[DPOConfig] = None
+    ppo_config: Optional[PPOConfig] = None
+    grpo_config: Optional[GRPOConfig] = None