project-llm-trainer 0.12.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,97 @@
+ from typing import Optional, Tuple, List
+ from torch.utils.data import Dataset
+
+ from llm_model import (
+     VLMConfig,
+     LlmModel,
+     VlmModel
+ )
+
+ from .base_trainer import BaseTrainer
+ from .train_configs import TrainConfig
+ from .dataset import SFTDataset
+ from .utils import get_sft_collate_fn
+ from .tools import TrainerTools
+
+
+ class SFTTrainer(BaseTrainer):
+     def __init__(
+         self,
+         *,
+         train_config: TrainConfig,
+         eval_prompts: List[str],
+         eval_image_tags: Optional[List[str]] = None
+     ):
+         self.sft_config = train_config.sft_config
+         self.pixel_values_provider = self.sft_config.pixel_values_provider
+         self.eval_image_tags = eval_image_tags
+
+         super().__init__(
+             train_config=train_config,
+             eval_prompts=eval_prompts,
+             kd_config=self.sft_config.kd_config,
+             gradient_accumulation_steps=self.sft_config.gradient_accumulation_steps
+         )
+
+         if isinstance(train_config.model_config, VLMConfig):
+             self.pixel_values_provider = self.sft_config.pixel_values_provider
+         else:
+             self.pixel_values_provider = None
+
+     def _new_model(self, train_config: TrainConfig):
+         if isinstance(train_config.model_config, VLMConfig):
+             return VlmModel(train_config.model_config)
+         else:
+             return LlmModel(train_config.model_config)
+
+     def _check_freeze_llm_model(self, model):
+         # freeze llm model for vlm training
+         if self.sft_config.freeze_llm_model:
+             for name, param in model.named_parameters():
+                 if not any(sub_module in name for sub_module in ['multi_modal_projector']):
+                     param.requires_grad = False
+
+             # model.embed_tokens.eval()
+             # model.layers.eval()
+             # model.head_norm.eval()
+             # model.lm_head.eval()
+
+     def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+         sft_collate_fn = get_sft_collate_fn(self.sft_config.mask_prompt)
+         parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+         data_loader_kwargs.update({"collate_fn": sft_collate_fn})
+
+         return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+     def _get_pixel_values(self, batch_data):
+         if self.pixel_values_provider and 'image_tags' in batch_data:
+             image_tags = batch_data['image_tags']
+             return self.pixel_values_provider(image_tags).to(TrainerTools().parallel.device)
+
+         return None
+
+     def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
+         if not self.eval_image_tags:
+             return None, None
+
+         eval_image_tag = self.eval_image_tags[eval_idx]
+         if isinstance(self.train_config.model_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
+             return self.pixel_values_provider([eval_image_tag]), self.train_config.model_config.tokens_per_image
+
+         return None, None
+
+     def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+         file_path = self.train_config.file_dataset[file_idx]
+         max_seq_len = self.train_config.max_seq_len
+
+         image_tag_file_path = None
+         tokens_per_image = -1
+
+         if isinstance(self.train_config.model_config, VLMConfig):
+             if self.sft_config.image_tags_file_dataset:
+                 image_tag_file_path = self.sft_config.image_tags_file_dataset[file_idx]
+
+             if self.train_config.model_config.tokens_per_image:
+                 tokens_per_image = self.train_config.model_config.tokens_per_image
+
+         return SFTDataset(file_path, max_seq_len, image_tag_file_path, tokens_per_image), file_path
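
The trainer above only relies on `pixel_values_provider` being a callable that maps a list of image tags to a single tensor: it is invoked as `self.pixel_values_provider(image_tags)` and the result is moved to the training device. A minimal sketch of such a provider, assuming the images have already been preprocessed into per-tag `.pt` tensor files; the directory layout and file naming below are illustrative, not part of this package:

```python
import os
from typing import List

import torch

PIXEL_VALUES_DIR = './pixel_values'  # hypothetical directory of preprocessed image tensors


def pixel_values_provider(image_tags: List[str]) -> torch.Tensor:
    # Each tag maps to a preprocessed tensor (e.g. shape (C, H, W)); stacking along
    # dim 0 yields the (batch, C, H, W) layout a vision tower typically expects.
    tensors = [torch.load(os.path.join(PIXEL_VALUES_DIR, f'{tag}.pt')) for tag in image_tags]
    return torch.stack(tensors, dim=0)
```
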
@@ -0,0 +1,162 @@
+ import os
+ import warnings
+ from typing import List, Dict, Union
+ from transformers import AutoTokenizer
+ import torch
+
+
+ class Tokenizer:
+     def __init__(self):
+         self.tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_DIR'])
+
+         self.text_end = '</s>'
+
+         self.text_pad = '<pad>'
+         self.text_unk = '<unk>'
+
+         self.text_user = '<user>'
+         self.text_assistant = '<assistant>'
+
+         self.text_think_start = '<think>'
+         self.text_think_end = '</think>'
+
+         self.text_answer_start = '<answer>'
+         self.text_answer_end = '</answer>'
+
+         self.text_system = '<system>'
+
+         self.text_image = '<image>'
+
+         self.end = self.tokenizer.convert_tokens_to_ids(self.text_end)
+
+         self.pad = self.tokenizer.convert_tokens_to_ids(self.text_pad)
+         self.unk = self.tokenizer.convert_tokens_to_ids(self.text_unk)
+
+         self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
+         self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
+
+         self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
+         self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)
+
+         self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
+         self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
+
+         self.system = self.tokenizer.convert_tokens_to_ids(self.text_system)
+         self.image = self.tokenizer.convert_tokens_to_ids(self.text_image)
+
+         self.vocab_size = len(self.tokenizer)
+
+     def encode(
+         self,
+         text: str,
+         unsqueeze: bool = False,
+         covert_tensor: bool = False
+     ) -> Union[torch.Tensor, List[int]]:
+         # [x,x,x]
+         encoded = self.tokenizer.encode(text, add_special_tokens=False)
+
+         if unsqueeze:
+             # tensor: [[x,x,x]]
+             return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
+         else:
+             # tensor: [x,x,x]
+             if covert_tensor:
+                 return torch.tensor(encoded, dtype=torch.long)
+
+             return encoded
+
+     def batch_encode(
+         self,
+         text: List[str],
+         padding=False,
+         truncation=False,
+         covert_tensor: bool = False,
+         return_attention_mask: bool = False
+     ) -> Union[torch.Tensor, List[List[int]]]:
+         encoded = self.tokenizer(
+             text,
+             padding=padding,
+             truncation=truncation,
+             return_attention_mask=return_attention_mask
+         )['input_ids']
+
+         if covert_tensor:
+             encoded = torch.tensor(encoded, dtype=torch.long)
+
+         return encoded
+
+     def decode(
+         self,
+         token: Union[torch.Tensor, List[int]],
+         skip_special_tokens: bool = False
+     ) -> str:
+         return self.tokenizer.decode(token, skip_special_tokens=skip_special_tokens)
+
+     def batch_decode(
+         self,
+         tokens: Union[torch.Tensor, List[int], List[List[int]]],
+         skip_special_tokens: bool = False
+     ) -> List[str]:
+         return self.tokenizer.batch_decode(tokens, skip_special_tokens=skip_special_tokens)
+
+     def encode_to_token(self, text: str, unsqueeze=True, covert_tensor=True):
+         warnings.warn('encode_to_token is deprecated. Please use `encode` instead.')
+         return self.encode(text, unsqueeze, covert_tensor)
+
+     def decode_to_text(self, token: torch.Tensor, skip_special_tokens: bool = False) -> str:
+         warnings.warn('decode_to_text is deprecated. Please use `decode` instead.')
+         return self.decode(token.squeeze(0), skip_special_tokens)
+
+     def apply_chat_template(
+         self,
+         conversations: List[Dict[str, str]],
+         tokenizer: bool = True,
+         add_answer_tag_for_assistant: bool = True,
+         unsqueeze=False,
+         covert_tensor=False
+     ):
+         """
+         [
+             {"role": "system", "content": "system prompt"},
+             {"role": "user", "content": "hello?"},
+             {"role": "assistant", "content": "hello"},
+             {"role": "user", "content": "hello hello?"},
+             {"role": "assistant", "think": "thinking", "content": "hello hello"},
+         ]
+         <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
+         """
+
+         chat_template = ''
+         support_roles = {'system': self.text_system, 'user': self.text_user, 'assistant': self.text_assistant}
+         for conversation in conversations:
+             role = conversation['role']
+             if role in support_roles:
+                 content = conversation['content']
+                 if add_answer_tag_for_assistant and role == 'assistant':
+                     content = f"{self.text_answer_start}{content}{self.text_answer_end}"
+
+                 if 'think' in conversation:
+                     content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"
+
+                 chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
+
+         if tokenizer:
+             return self.encode(chat_template, unsqueeze, covert_tensor)
+
+         return chat_template
+
+     def get_special_tokens_dict(self):
+         return {
+             self.text_end: self.end,
+             self.text_pad: self.pad,
+             self.text_unk: self.unk,
+             self.text_user: self.user,
+             self.text_assistant: self.assistant,
+             self.text_think_start: self.think_start,
+             self.text_think_end: self.think_end,
+             self.text_answer_start: self.answer_start,
+             self.text_answer_end: self.answer_end,
+             self.text_system: self.system,
+             self.text_image: self.image,
+         }
+
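
Following the docstring of `apply_chat_template`, a small usage sketch. The `TOKEN_DIR` path is illustrative and must point at a tokenizer directory that already defines the special tokens listed in `__init__`; the import path assumes the tokenizer module lives in the `llm_trainer` package alongside `tools.py`, which imports it as `.tokenizer`:

```python
import os

os.environ['TOKEN_DIR'] = './tokenizer_dir'  # illustrative path to a prepared tokenizer

from llm_trainer.tokenizer import Tokenizer  # assumed module path

tokenizer = Tokenizer()
conversations = [
    {"role": "system", "content": "system prompt"},
    {"role": "user", "content": "hello?"},
    {"role": "assistant", "think": "thinking", "content": "hello"},
]

# tokenizer=False returns the raw template string:
# <system>system prompt</s><user>hello?</s><assistant><think>thinking</think><answer>hello</answer></s>
text = tokenizer.apply_chat_template(conversations, tokenizer=False)

# With the defaults (tokenizer=True) the same string is encoded into a plain list of token ids.
ids = tokenizer.apply_chat_template(conversations)
```
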
llm_trainer/tools.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ from abc import ABC, abstractmethod
+ import torch
+ from .tokenizer import Tokenizer
+ from .parallel import DsParallel, DdpParallel, NoneParallel
+ from .log import Logger
+
+
+ parallel_types = {
+     'ds': DsParallel,
+     'ddp': DdpParallel,
+     'none': NoneParallel
+ }
+
+ dtypes = {
+     'float': torch.float,
+     'float16': torch.float16,
+     'float32': torch.float32,
+     'float64': torch.float64
+ }
+
+ class TrainerTools:
+     def __init__(self):
+         if not hasattr(TrainerTools, "_first_init"):
+             TrainerTools._first_init = True
+
+             self.parallel = self._new_parallel()
+
+             self.tokenizer = Tokenizer()
+             self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
+
+             Logger.std_log(f'world_size={self.parallel.world_size}, use_amp={self.use_amp}')
+
+     def _new_parallel(self):
+         parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
+         Logger.std_log(f'parallel_type={parallel_type}')
+         return parallel_types[parallel_type]()
+
+     def __new__(cls, *args, **kwargs):
+         if not hasattr(TrainerTools, "_instance"):
+             TrainerTools._instance = object.__new__(cls)
+
+         return TrainerTools._instance
+
+
+ class FileDataset(ABC):
+     @abstractmethod
+     def __len__(self) -> int: ...
+
+     @abstractmethod
+     def __getitem__(self, idx) -> str: ...
+
+
+ def estimate_data_size(
+     file_dataset: FileDataset,
+     max_seq_len: int,
+     type: str
+ ) -> int:
+     """
+     Estimate the total number of samples across all dataset files.
+     """
+     data_size = 0
+     files_count = len(file_dataset)
+
+     if type == 'sft':
+         from .dataset import SFTDataset
+         for idx in range(files_count):
+             dataset = SFTDataset(file_dataset[idx], max_seq_len)
+             data_size += len(dataset)
+     elif type == 'dpo':
+         from .dataset import DPODataset
+         for idx in range(files_count):
+             dataset = DPODataset(file_dataset[idx], max_seq_len)
+             data_size += len(dataset)
+     elif type == 'grpo' or type == 'ppo':
+         from .dataset import RLDataset
+         for idx in range(files_count):
+             dataset = RLDataset(file_dataset[idx])
+             data_size += len(dataset)
+     else:
+         from .dataset import PretrainDataset
+         for idx in range(files_count):
+             dataset = PretrainDataset(
+                 file_dataset[idx],
+                 max_seq_len,
+                 max_seq_len
+             )
+             data_size += len(dataset)
+
+     return data_size
+
+
+ def extract_policy_weights_from_ppo(model_config, ppo_weights):
+     from llm_model import LlmModel
+     from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel
+
+     policy_model = LlmModel(model_config)
+     value_model = ValueModel(LlmModel(model_config))
+
+     wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
+     wrapper.load_state_dict(ppo_weights)
+
+     return wrapper.policy_model.state_dict()
+
+
+ def extract_value_weights_from_ppo(model_config, ppo_weights):
+     from llm_model import LlmModel
+     from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel
+
+     policy_model = LlmModel(model_config)
+     value_model = ValueModel(LlmModel(model_config))
+
+     wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
+     wrapper.load_state_dict(ppo_weights)
+
+     return wrapper.value_model.state_dict()
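
`FileDataset` is the abstraction the trainers index into for per-file dataset paths (`_create_dataset` above does `self.train_config.file_dataset[file_idx]`), and `estimate_data_size` iterates it the same way. A minimal sketch of a concrete implementation; the class name and directory argument are illustrative:

```python
import os
from typing import List

from llm_trainer.tools import FileDataset


class DirFileDataset(FileDataset):
    """Illustrative FileDataset that yields the path of every file in a directory."""

    def __init__(self, data_dir: str):
        self.files: List[str] = sorted(
            os.path.join(data_dir, name) for name in os.listdir(data_dir)
        )

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx) -> str:
        return self.files[idx]
```

Note also that `TrainerTools` overrides `__new__`, so every call to `TrainerTools()` returns the same singleton instance; the SFT trainer above relies on this when it reads `TrainerTools().parallel.device`.
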
@@ -0,0 +1,324 @@
+ from typing import Optional, Union, Callable, List, Mapping, Any, Tuple
+ from dataclasses import dataclass, field
+
+ import torch
+ from llm_model import ModelConfig, VLMConfig
+ from .tools import FileDataset
+
+
+ @dataclass(kw_only=True)
+ class DsOffloadConfig:
+     device: str = 'cpu'
+     pin_memory: bool = True
+
+
+ @dataclass(kw_only=True)
+ class DsActivationCheckpointingConfig:
+     partition_activations: bool = True
+     cpu_checkpointing: bool = False
+     contiguous_memory_optimization: bool = True
+     number_checkpoints: Optional[int] = None
+     synchronize_checkpoint_boundary: bool = False
+     profile: bool = False
+
+
+ @dataclass(kw_only=True)
+ class DsZeROConfig:
+     stage: int
+     allgather_partitions: Optional[bool] = True
+     allgather_bucket_size: Optional[int] = 5e8
+     overlap_comm: Optional[bool] = True
+     reduce_scatter: Optional[bool] = True
+     reduce_bucket_size: Optional[Union[str, int]] = 5e8
+     contiguous_gradients: Optional[bool] = True
+
+ @dataclass(kw_only=True)
+ class DsZero0Config(DsZeROConfig):
+     stage: int = field(default=0, init=False)
+
+
+ @dataclass(kw_only=True)
+ class DsZero1Config(DsZeROConfig):
+     stage: int = field(default=1, init=False)
+
+
+ @dataclass(kw_only=True)
+ class DsZero2Config(DsZeROConfig):
+     stage: int = field(default=2, init=False)
+     offload_optimizer: Optional[DsOffloadConfig] = None
+     offload_param: Optional[DsOffloadConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class DsZero3Config(DsZeROConfig):
+     stage: int = field(default=3, init=False)
+     sub_group_size: Optional[int] = 1e9
+     stage3_prefetch_bucket_size: Optional[Union[str, int]] = 'auto'
+     stage3_param_persistence_threshold: Optional[Union[str, int]] = 'auto'
+     stage3_max_live_parameters: Optional[int] = 1e9
+     stage3_max_reuse_distance: Optional[int] = 1e9
+     stage3_gather_16bit_weights_on_model_save: Optional[bool] = True
+     offload_optimizer: Optional[DsOffloadConfig] = None
+     offload_param: Optional[DsOffloadConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class DsFp16Config:
+     enabled: Union[str, bool] = 'auto'
+     loss_scale: int = 0
+     loss_scale_window: int = 1000
+     initial_scale_power: int = 16
+     hysteresis: int = 2
+     min_loss_scale: int = 1
+     fp16_opt_level: Optional[str] = 'O2'
+
+
+ @dataclass(kw_only=True)
+ class DsBf16Config:
+     enabled: bool = True
+
+
+ @dataclass(kw_only=True)
+ class DsConfig:
+     zero_config: Optional[DsZeROConfig] = field(default_factory=DsZero3Config)
+     fp16_config: Optional[DsFp16Config] = field(default_factory=DsFp16Config)
+     bf16_config: Optional[DsBf16Config] = field(default_factory=DsBf16Config)
+     gradient_clipping: Optional[float] = 1.0
+     activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class DataLoaderConfig:
+     """
+     Data loader configuration.
+     Args:
+         data_loader_pin_memory (`bool`, *optional*, default is False):
+             data_loader pin_memory config
+         data_loader_num_workers (`int`, *optional*, default is 0):
+             data_loader num_workers config
+         data_loader_shuffle (`bool`, *optional*, default is False):
+             Whether to shuffle the data.
+         data_loader_drop_last (`bool`, default is True):
+             Whether to drop the last batch when it is smaller than batch_size.
+     """
+     data_loader_pin_memory: bool = False
+     data_loader_num_workers: int = 0
+     data_loader_shuffle: bool = False
+     data_loader_drop_last: bool = True
+
+
+ @dataclass(kw_only=True)
+ class OptimConfig:
+     optim_type: str = 'adam'  # or 'lion'
+     enable_lr_scheduler: bool = False
+     initial_lr: float
+     weight_decay: Optional[float] = None
+     betas: Optional[Tuple[float, float]] = None
+     warmup_iters: Optional[int] = None
+     max_lr: Optional[float] = None
+     min_lr: Optional[float] = None
+     cosine_annealing_period: Optional[int] = None
+     cosine_annealing_period_mul: int = 0
+
+
+ @dataclass(kw_only=True)
+ class LossConfig:
+     critical_tokens: Optional[List[int]] = None
+     critical_alpha: float = 1.0
+     aux_loss_coef: Optional[float] = 0.001
+
+
+ @dataclass(kw_only=True)
+ class KDConfig:
+     """
+     Knowledge distillation configuration.
+
+     Args:
+         teacher_logits_provider (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
+             Provider of the teacher model's logits for knowledge distillation.
+         kd_coef (`float`, *optional*, default is 0.4):
+             Weight of the distillation loss: loss = kd_coef * kd_loss + (1 - kd_coef) * lm_loss
+     """
+     teacher_logits_provider: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+     kd_coef: float = 0.4
+
+
+ @dataclass(kw_only=True)
+ class EvalConfig:
+     """
+     Evaluation configuration.
+
+     Args:
+         eval_batch_interval (`int`, default is 100):
+             Run a model eval every `eval_batch_interval` batches.
+     """
+     max_new_tokens: int
+     eval_batch_interval: int = 100
+     temperature: float = 1.0
+     top_p: float = 0.95
+     top_k: Optional[float] = None
+
+
+ @dataclass(kw_only=True)
+ class PretrainConfig:
+     """
+     Pretraining configuration.
+
+     Args:
+         gradient_accumulation_steps (`int`, *optional*, default is 1):
+             Number of gradient accumulation steps; 0 disables gradient accumulation.
+             Currently only applies to pretrain/sft/dpo, not to ppo/grpo/gspo.
+         kd_config (`KDConfig`, *optional*, default is None):
+             Knowledge distillation configuration; None disables knowledge distillation.
+     """
+     gradient_accumulation_steps: int = 1
+     kd_config: Optional[KDConfig] = None
+
+
+ @dataclass(kw_only=True)
+ class SFTConfig:
+     """
+     SFT configuration.
+
+     Args:
+         mask_prompt (`bool`):
+             Whether to mask the prompt tokens.
+         gradient_accumulation_steps (`int`, *optional*, default is 1):
+             Number of gradient accumulation steps; 0 disables gradient accumulation.
+             Currently only applies to pretrain/sft/dpo, not to ppo/grpo/gspo.
+         kd_config (`KDConfig`, *optional*, default is None):
+             Knowledge distillation configuration; None disables knowledge distillation.
+         pixel_values_provider (`Callable[[list[str]], torch.Tensor]`, *optional*, default is None):
+             Provides pixel_values for the given image_tags when training a VLM.
+         freeze_llm_model:
+             Whether to freeze the LLM part of the model parameters when training a VLM.
+     """
+     mask_prompt: bool = True
+     gradient_accumulation_steps: int = 1
+     kd_config: Optional[KDConfig] = None
+     image_tags_file_dataset: Optional[FileDataset] = None
+     pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None
+     freeze_llm_model: bool = False
+
+
+ @dataclass(kw_only=True)
+ class DPOConfig:
+     """
+     DPO configuration.
+
+     Args:
+         mask_prompt (`bool`):
+             Whether to mask the prompt tokens.
+         gradient_accumulation_steps (`int`, *optional*, default is 1):
+             Number of gradient accumulation steps; 0 disables gradient accumulation.
+             Currently only applies to pretrain/sft/dpo, not to ppo/grpo/gspo.
+     """
+     ref_model_checkpoint: Mapping[str, Any]
+     mask_prompt: bool = True
+     gradient_accumulation_steps: int = 1
+     loss_beta: float
+     loss_label_smoothing: float = 0.0
+     loss_ipo: bool = False
+     nll_loss_coef: Optional[float] = None
+
+
+ @dataclass(kw_only=True)
+ class PPOConfig:
+     ppo_epochs: int
+     ppo_batch_size: int
+     ref_model_checkpoint: Mapping[str, Any]
+     value_model_checkpoint: Optional[Mapping[str, Any]] = None
+     gradient_accumulation_steps: int = 1
+     gamma: float = 1.0
+     lam: float = 0.95
+     clip_eps: float = 0.1
+     vf_coef: float = 0.5
+     kl_beta: float = 0.02
+     kl_estimator: str = 'k1'  # or 'k3'
+     missing_eos_penalty: Optional[float] = None
+     normalize_rewards: bool = False
+     whiten_rewards: bool = False
+     gen_max_new_tokens: int
+     gen_temperature: Optional[float] = None
+     gen_k: Optional[int] = None
+     gen_p: Optional[float] = None
+     gen_suppress_tokens: Optional[list[int]] = None
+
+
+ @dataclass(kw_only=True)
+ class GRPOConfig:
+     grpo_steps: int = 1
+     group_size: int = 12
+     mixup_alpha: float = 1.0
+     loss_beta: float = 0.0  # or 0.04 for grpo
+     loss_clip_eps: float = 3e-4
+     loss_clip_eps_high: Optional[float] = 4e-4
+     loss_delta: Optional[float] = None
+     loss_importance_sampling_level: str = 'seq'  # 'token' or 'seq'
+     loss_type: str = 'grpo'  # 'grpo', 'bnpo' or 'dr_grpo'
+     gen_max_new_tokens: int
+     gen_temperature: Optional[float] = None
+     gen_k: Optional[int] = None
+     gen_p: Optional[float] = None
+     gen_suppress_tokens: Optional[list[int]] = None
+
+
+ @dataclass(kw_only=True)
+ class TrainConfig:
+     """
+     Training parameter configuration.
+
+     Args:
+         n_epochs (`int`):
+             Number of training epochs.
+         batch_size (`int`):
+             Size of each batch.
+         model_config (`ModelConfig`):
+             Model configuration.
+         init_state_dict:
+             Initial checkpoint state dict.
+         file_dataset (`FileDataset`):
+             FileDataset of training files.
+         max_seq_len (`int`, default is None):
+             Maximum training sequence length; the model's max_position_embedding is used when None.
+         data_loader_config (`DataLoaderConfig`):
+             Data loader configuration.
+         loss_config:
+             Loss configuration.
+         ds_config:
+             DeepSpeed configuration.
+         eval_config:
+             Eval configuration.
+         optim_config (`OptimConfig`):
+             Optimizer configuration.
+         pretrain_config:
+             Pretraining configuration; only used with Trainer.
+         sft_config:
+             SFT configuration; only used with SFTTrainer.
+         dpo_config:
+             DPO configuration; only used with DPOTrainer.
+         ppo_config:
+             PPO configuration; only used with PPOTrainer.
+         grpo_config:
+             GRPO configuration; only used with GRPOTrainer.
+     """
+     n_epochs: int
+     batch_size: int
+     model_config: Union[ModelConfig, VLMConfig]
+     init_state_dict: Optional[Mapping[str, Any]] = None
+
+     file_dataset: FileDataset
+     max_seq_len: int
+     data_loader_config: DataLoaderConfig = field(default_factory=DataLoaderConfig)
+
+     loss_config: LossConfig = field(default_factory=LossConfig)
+     optim_config: OptimConfig = field(default_factory=OptimConfig)
+     ds_config: DsConfig = field(default_factory=DsConfig)
+
+     eval_config: EvalConfig = field(default_factory=EvalConfig)
+
+     pretrain_config: Optional[PretrainConfig] = None
+     sft_config: Optional[SFTConfig] = None
+     dpo_config: Optional[DPOConfig] = None
+     ppo_config: Optional[PPOConfig] = None
+     grpo_config: Optional[GRPOConfig] = None
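
To tie these dataclasses back to the `SFTTrainer` at the top of this diff, a hedged configuration sketch. `ModelConfig`'s own fields live in the separate `llm_model` package and the SFT trainer's module path is not named in this diff, so the `ModelConfig()` arguments and the `SFTTrainer` import below are assumptions; `DirFileDataset` refers to the illustrative `FileDataset` subclass sketched after `llm_trainer/tools.py`:

```python
from llm_model import ModelConfig

from llm_trainer.train_configs import TrainConfig, SFTConfig, OptimConfig, EvalConfig
from llm_trainer import SFTTrainer  # assumed export; the trainer's module name is not shown in this diff

train_config = TrainConfig(
    n_epochs=1,
    batch_size=8,
    model_config=ModelConfig(),                 # model hyperparameters omitted; defined in llm_model
    file_dataset=DirFileDataset('./sft_data'),  # illustrative FileDataset subclass from the earlier sketch
    max_seq_len=1024,
    optim_config=OptimConfig(initial_lr=1e-5),
    eval_config=EvalConfig(max_new_tokens=128),
    sft_config=SFTConfig(mask_prompt=True),
)

trainer = SFTTrainer(
    train_config=train_config,
    eval_prompts=['hello?'],
)
```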