project-llm-trainer 0.13.4 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,220 @@
+from abc import ABC, abstractmethod
+from typing import List, Optional
+import math
+import torch
+from .log import Logger
+
+class LRScheduler(ABC):
+    @property
+    @abstractmethod
+    def cur_steps(self): ...
+
+    @property
+    @abstractmethod
+    def cur_lr(self): ...
+
+    @abstractmethod
+    def step(self): ...
+
+    @abstractmethod
+    def can_clip_grad(self): ...
+
+    @abstractmethod
+    def get_ckpt_dict(self) -> dict: ...
+
+    @abstractmethod
+    def restore_ckpt_dict(self, ckpt: dict): ...
+
+
+class WarmupCosineAnnealingLRScheduler(LRScheduler):
+    def __init__(
+        self,
+        *,
+        optimizer: torch.optim.Optimizer,
+        warmup_iters: int,
+        initial_lr: float,
+        min_lr: float,
+        max_lr: float,
+        cosine_annealing_period: int,  # steps per annealing cycle
+        cosine_annealing_period_mul: int = 0,  # multiplier applied to each successive cycle length
+        param_group_indices: Optional[List[int]] = None,
+        need_log: bool = False
+    ):
+        super().__init__()
+
+        self._optimizer = optimizer
+        self._initial_lr = initial_lr
+        self._min_lr = min_lr
+        self._max_lr = max_lr
+        self._warmup_iters = warmup_iters
+
+        self._cosine_annealing_period = cosine_annealing_period
+        self._cosine_annealing_period_mul = cosine_annealing_period_mul
+
+        self.param_group_indices = param_group_indices
+
+        self.T_cur = 0  # steps taken within the current cycle
+        self.cycle = 0  # index of the current cycle
+
+        if warmup_iters != 0:
+            self._lr_increment = (max_lr - initial_lr) / warmup_iters
+        else:
+            self._lr_increment = 0
+
+        self._steps = -1
+        self._current_lr = initial_lr
+        self._cosine_annealing_base_lr = None
+
+        if need_log:
+            self.logger = Logger('lr.txt')
+        else:
+            self.logger = None
+
+    @property
+    def cur_steps(self):
+        return self._steps
+
+    @property
+    def cur_lr(self):
+        return self._current_lr
+
+    def step(self):
+        self._steps += 1
+        self._update_lr()
+
+    def can_clip_grad(self):
+        return self._steps > self._warmup_iters
+
+    def _update_lr(self):
+        if self.param_group_indices is None:
+            target_groups = self._optimizer.param_groups
+        else:
+            target_groups = [self._optimizer.param_groups[i] for i in self.param_group_indices]
+
+        # If period_mul is 0 there are no restarts: past the single annealing window, hold min_lr.
+        if self._cosine_annealing_period_mul == 0 and self._steps >= self._cosine_annealing_period + self._warmup_iters:
+            lr = self._min_lr
+            for param_group in target_groups:
+                param_group['lr'] = lr
+        elif self._steps <= self._warmup_iters:
+            # Warmup: increase the learning rate linearly,
+            # adding (max_lr - initial_lr) / warmup_iters per step.
+            lr = self._initial_lr + self._steps * self._lr_increment
+            for param_group in target_groups:
+                param_group['lr'] = lr
+        else:
+            if self._cosine_annealing_base_lr is None:
+                self._cosine_annealing_base_lr = self.cur_lr
+
+            # Cosine annealing: update the learning rate on every step.
+            # Maximum number of steps in the current cycle.
+            T_max = self._cosine_annealing_period * (max(self._cosine_annealing_period_mul, 1) ** self.cycle)
+
+            # Advance the cycle state.
+            self.T_cur += 1
+            calc_t = self.T_cur
+
+            if self.T_cur >= T_max:
+                if self._cosine_annealing_period_mul == 0:
+                    self.T_cur = T_max
+                    calc_t = T_max
+                else:
+                    self.cycle += 1
+                    self.T_cur = 0
+                    calc_t = T_max
+
+            # Compute and apply the new learning rate.
+            cos_factor = (1 + math.cos(math.pi * calc_t / T_max)) / 2
+            lr = self._min_lr + (self._cosine_annealing_base_lr - self._min_lr) * cos_factor
+
+            for param_group in target_groups:
+                param_group['lr'] = lr
+
+        self._current_lr = lr
+
+        if self.logger:
+            self.logger.log(f"step: {self.cur_steps}, lr: {lr}", log_to_console=False)
+
+    def get_ckpt_dict(self) -> dict:
+        return {
+            'cur_lr': self._current_lr,
+            'lr_steps': self.cur_steps,
+            'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
+            't_cur': self.T_cur,
+            'cycle': self.cycle,
+        }
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if 'cur_lr' in ckpt:
+            self._current_lr = ckpt['cur_lr']
+
+        if 'lr_steps' in ckpt:
+            self._steps = ckpt['lr_steps']
+
+        if 'cosine_annealing_base_lr' in ckpt:
+            self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']
+
+        if 't_cur' in ckpt:
+            self.T_cur = ckpt['t_cur']
+
+        if 'cycle' in ckpt:
+            self.cycle = ckpt['cycle']
+
+        self._update_lr()
+
+
+class NoneLRScheduler(LRScheduler):
+    def __init__(self, initial_lr):
+        self._current_lr = initial_lr
+
+    @property
+    def cur_steps(self):
+        return -1
+
+    @property
+    def cur_lr(self):
+        return self._current_lr
+
+    def step(self): ...
+
+    def can_clip_grad(self):
+        return True
+
+    def get_ckpt_dict(self) -> dict:
+        return {'cur_lr': self._current_lr}
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if 'cur_lr' in ckpt:
+            self._current_lr = ckpt['cur_lr']
+
+
+class CompositeLRScheduler(LRScheduler):
+    def __init__(self, schedulers: List[LRScheduler]):
+        self.schedulers = schedulers
+
+    @property
+    def cur_steps(self):
+        return self.schedulers[0].cur_steps if self.schedulers else 0
+
+    @property
+    def cur_lr(self):
+        return self.schedulers[0].cur_lr if self.schedulers else 0.0
+
+    def step(self):
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+    def can_clip_grad(self):
+        return all(s.can_clip_grad() for s in self.schedulers)
+
+    def get_ckpt_dict(self) -> dict:
+        ckpt = {}
+        for i, scheduler in enumerate(self.schedulers):
+            ckpt[f'scheduler_{i}'] = scheduler.get_ckpt_dict()
+        return ckpt
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        for i, scheduler in enumerate(self.schedulers):
+            key = f'scheduler_{i}'
+            if key in ckpt:
+                scheduler.restore_ckpt_dict(ckpt[key])
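
For orientation, a minimal usage sketch of the scheduler above. The import path is an assumption (the viewer omitted this hunk's filename); the constructor arguments follow the signature shown in the diff.

import torch
# Assumed module path; the diff does not name this file.
from llm_trainer.lr_scheduler import WarmupCosineAnnealingLRScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

scheduler = WarmupCosineAnnealingLRScheduler(
    optimizer=optimizer,
    warmup_iters=100,               # linear ramp from initial_lr to max_lr
    initial_lr=1e-6,
    min_lr=1e-5,
    max_lr=5e-4,
    cosine_annealing_period=1000,   # first (and here only) cosine cycle: 1000 steps
    cosine_annealing_period_mul=0,  # 0 disables restarts; lr stays at min_lr afterwards
)

for _ in range(1200):
    optimizer.step()   # forward/backward elided
    scheduler.step()   # call once per optimizer step; writes param_group['lr']

print(scheduler.cur_steps, scheduler.cur_lr)  # past step 1100, cur_lr == min_lr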
@@ -0,0 +1,97 @@
+from typing import Optional, Tuple, List
+from torch.utils.data import Dataset
+
+from llm_model import (
+    VLMConfig,
+    LlmModel,
+    VlmModel
+)
+
+from .base_trainer import BaseTrainer
+from .train_configs import TrainConfig
+from .dataset import SFTDataset
+from .utils import get_sft_collate_fn
+from .tools import TrainerTools
+
+
+class SFTTrainer(BaseTrainer):
+    def __init__(
+        self,
+        *,
+        train_config: TrainConfig,
+        eval_prompts: List[str],
+        eval_image_tags: Optional[List[str]] = None
+    ):
+        self.sft_config = train_config.sft_config
+        self.eval_image_tags = eval_image_tags
+
+        # Pixel values are only meaningful when training a VLM.
+        if isinstance(train_config.model_config, VLMConfig):
+            self.pixel_values_provider = self.sft_config.pixel_values_provider
+        else:
+            self.pixel_values_provider = None
+
+        super().__init__(
+            train_config=train_config,
+            eval_prompts=eval_prompts,
+            kd_config=self.sft_config.kd_config,
+            gradient_accumulation_steps=self.sft_config.gradient_accumulation_steps
+        )
+
+    def _new_model(self, train_config: TrainConfig):
+        if isinstance(train_config.model_config, VLMConfig):
+            return VlmModel(train_config.model_config)
+        else:
+            return LlmModel(train_config.model_config)
+
+    def _check_freeze_llm_model(self, model):
+        # Freeze the LLM backbone for VLM training; only the projector stays trainable.
+        if self.sft_config.freeze_llm_model:
+            for name, param in model.named_parameters():
+                if not any(sub_module in name for sub_module in ['multi_modal_projector']):
+                    param.requires_grad = False
+
+        # model.embed_tokens.eval()
+        # model.layers.eval()
+        # model.head_norm.eval()
+        # model.lm_head.eval()
+
+    def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+        sft_collate_fn = get_sft_collate_fn(self.sft_config.mask_prompt)
+        parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+        data_loader_kwargs.update({"collate_fn": sft_collate_fn})
+
+        return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+    def _get_pixel_values(self, batch_data):
+        if self.pixel_values_provider and 'image_tags' in batch_data:
+            image_tags = batch_data['image_tags']
+            return self.pixel_values_provider(image_tags).to(TrainerTools().parallel.device)
+
+        return None
+
+    def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
+        if not self.eval_image_tags:
+            return None, None
+
+        eval_image_tag = self.eval_image_tags[eval_idx]
+        if isinstance(self.train_config.model_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
+            return self.pixel_values_provider([eval_image_tag]), self.train_config.model_config.tokens_per_image
+
+        return None, None
+
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
+        block_size = self.train_config.dataset_block_size
+
+        image_tag_file_path = None
+        tokens_per_image = -1
+
+        if isinstance(self.train_config.model_config, VLMConfig):
+            if self.sft_config.image_tags_file_dataset:
+                image_tag_file_path = self.sft_config.image_tags_file_dataset[file_idx]
+
+            if self.train_config.model_config.tokens_per_image:
+                tokens_per_image = self.train_config.model_config.tokens_per_image
+
+        return SFTDataset(file_path, block_size, image_tag_file_path, tokens_per_image), file_path
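
The trainer above only assumes that pixel_values_provider is a callable taking a list of image tags and returning a batched pixel tensor (the trainer calls .to(device) itself). A stand-in sketch; the tensor shape and the caching are illustrative assumptions, not part of the package's contract:

from typing import List
import torch

class DummyPixelValuesProvider:
    """Maps image tags to (3, 224, 224) tensors; random data for illustration."""

    def __init__(self):
        self._cache = {}

    def __call__(self, image_tags: List[str]) -> torch.Tensor:
        tensors = []
        for tag in image_tags:
            if tag not in self._cache:
                # A real provider would load and preprocess the tagged image here.
                self._cache[tag] = torch.randn(3, 224, 224)
            tensors.append(self._cache[tag])
        return torch.stack(tensors)  # (len(image_tags), 3, 224, 224)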
llm_trainer/tokenizer.py ADDED
@@ -0,0 +1,162 @@
+import os
+import warnings
+from typing import List, Dict, Union
+from transformers import AutoTokenizer
+import torch
+
+
+class Tokenizer:
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_DIR'])
+
+        self.text_end = '</s>'
+
+        self.text_pad = '<pad>'
+        self.text_unk = '<unk>'
+
+        self.text_user = '<user>'
+        self.text_assistant = '<assistant>'
+
+        self.text_think_start = '<think>'
+        self.text_think_end = '</think>'
+
+        self.text_answer_start = '<answer>'
+        self.text_answer_end = '</answer>'
+
+        self.text_system = '<system>'
+
+        self.text_image = '<image>'
+
+        self.end = self.tokenizer.convert_tokens_to_ids(self.text_end)
+
+        self.pad = self.tokenizer.convert_tokens_to_ids(self.text_pad)
+        self.unk = self.tokenizer.convert_tokens_to_ids(self.text_unk)
+
+        self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
+        self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
+
+        self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
+        self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)
+
+        self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
+        self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
+
+        self.system = self.tokenizer.convert_tokens_to_ids(self.text_system)
+        self.image = self.tokenizer.convert_tokens_to_ids(self.text_image)
+
+        self.vocab_size = len(self.tokenizer)
+
+    def encode(
+        self,
+        text: str,
+        unsqueeze: bool = False,
+        covert_tensor: bool = False
+    ) -> Union[torch.Tensor, List[int]]:
+        # list: [x, x, x]
+        encoded = self.tokenizer.encode(text, add_special_tokens=False)
+
+        if unsqueeze:
+            # tensor: [[x, x, x]]
+            return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
+        else:
+            # tensor: [x, x, x]
+            if covert_tensor:
+                return torch.tensor(encoded, dtype=torch.long)
+
+            return encoded
+
+    def batch_encode(
+        self,
+        text: List[str],
+        padding=False,
+        truncation=False,
+        covert_tensor: bool = False,
+        return_attention_mask: bool = False
+    ) -> Union[torch.Tensor, List[List[int]]]:
+        encoded = self.tokenizer(
+            text,
+            padding=padding,
+            truncation=truncation,
+            return_attention_mask=return_attention_mask
+        )['input_ids']  # note: only input_ids are kept; any attention mask is discarded
+
+        if covert_tensor:
+            encoded = torch.tensor(encoded, dtype=torch.long)
+
+        return encoded
+
+    def decode(
+        self,
+        token: Union[torch.Tensor, List[int]],
+        skip_special_tokens: bool = False
+    ) -> str:
+        return self.tokenizer.decode(token, skip_special_tokens=skip_special_tokens)
+
+    def batch_decode(
+        self,
+        tokens: Union[torch.Tensor, List[int], List[List[int]]],
+        skip_special_tokens: bool = False
+    ) -> List[str]:
+        return self.tokenizer.batch_decode(tokens, skip_special_tokens=skip_special_tokens)
+
+    def encode_to_token(self, text: str, unsqueeze=True, covert_tensor=True):
+        warnings.warn('encode_to_token is deprecated. Please use `encode` instead.', DeprecationWarning)
+        return self.encode(text, unsqueeze, covert_tensor)
+
+    def decode_to_text(self, token: torch.Tensor, skip_special_tokens: bool = False) -> str:
+        warnings.warn('decode_to_text is deprecated. Please use `decode` instead.', DeprecationWarning)
+        return self.decode(token.squeeze(0), skip_special_tokens)
+
+    def apply_chat_template(
+        self,
+        conversations: List[Dict[str, str]],
+        tokenizer: bool = True,
+        add_answer_tag_for_assistant: bool = True,
+        unsqueeze=False,
+        covert_tensor=False
+    ):
+        """
+        [
+            {"role": "system", "content": "system prompt"},
+            {"role": "user", "content": "hello?"},
+            {"role": "assistant", "content": "hello"},
+            {"role": "user", "content": "hello hello?"},
+            {"role": "assistant", "think": "thinking", "content": "hello hello"},
+        ]
+        <system>{system_prompt}</s><user>hello?</s><assistant><answer>hello</answer></s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
+        """
+
+        chat_template = ''
+        support_roles = {'system': self.text_system, 'user': self.text_user, 'assistant': self.text_assistant}
+        for conversation in conversations:
+            role = conversation['role']
+            if role in support_roles:
+                content = conversation['content']
+                if add_answer_tag_for_assistant and role == 'assistant':
+                    content = f"{self.text_answer_start}{content}{self.text_answer_end}"
+
+                if 'think' in conversation:
+                    content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"
+
+                chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
+
+        if tokenizer:
+            return self.encode(chat_template, unsqueeze, covert_tensor)
+
+        return chat_template
+
+    def get_special_tokens_dict(self):
+        return {
+            self.text_end: self.end,
+            self.text_pad: self.pad,
+            self.text_unk: self.unk,
+            self.text_user: self.user,
+            self.text_assistant: self.assistant,
+            self.text_think_start: self.think_start,
+            self.text_think_end: self.think_end,
+            self.text_answer_start: self.answer_start,
+            self.text_answer_end: self.answer_end,
+            self.text_system: self.system,
+            self.text_image: self.image,
+        }
+
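
A short usage sketch for the chat template above. It assumes TOKEN_DIR points at a local Hugging Face tokenizer directory that already defines the special tokens listed in __init__ (the diff does not ship one), so treat the path below as a placeholder:

import os
os.environ['TOKEN_DIR'] = '/path/to/tokenizer'  # placeholder

from llm_trainer.tokenizer import Tokenizer

tok = Tokenizer()
conversations = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "hello?"},
    {"role": "assistant", "think": "greeting back", "content": "hello"},
]

# tokenizer=False returns the raw template string:
# <system>You are helpful.</s><user>hello?</s><assistant><think>greeting back</think><answer>hello</answer></s>
text = tok.apply_chat_template(conversations, tokenizer=False)

# With the defaults the template is encoded; unsqueeze=True gives shape (1, seq_len).
ids = tok.apply_chat_template(conversations, unsqueeze=True, covert_tensor=True)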
llm_trainer/tools.py ADDED
@@ -0,0 +1,116 @@
+import os
+from abc import ABC, abstractmethod
+import torch
+from .tokenizer import Tokenizer
+from .parallel import DsParallel, DdpParallel, NoneParallel
+from .log import Logger
+
+
+parallel_types = {
+    'ds': DsParallel,
+    'ddp': DdpParallel,
+    'none': NoneParallel
+}
+
+dtypes = {
+    'float': torch.float,
+    'float16': torch.float16,
+    'float32': torch.float32,
+    'float64': torch.float64
+}
+
+class TrainerTools:
+    def __init__(self):
+        if not hasattr(TrainerTools, "_first_init"):  # run once; __new__ below makes this a singleton
+            TrainerTools._first_init = True
+
+            self.parallel = self._new_parallel()
+
+            self.tokenizer = Tokenizer()
+            self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)
+
+            Logger.std_log(f'world_size={self.parallel.world_size}, use_amp={self.use_amp}')
+
+    def _new_parallel(self):
+        parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
+        Logger.std_log(f'parallel_type={parallel_type}')
+        return parallel_types[parallel_type]()
+
+    def __new__(cls, *args, **kwargs):
+        if not hasattr(TrainerTools, "_instance"):
+            TrainerTools._instance = object.__new__(cls)
+
+        return TrainerTools._instance
+
+
+class FileDataset(ABC):
+    @abstractmethod
+    def __len__(self) -> int: ...
+
+    @abstractmethod
+    def __getitem__(self, idx) -> str: ...
+
+
+def estimate_data_size(
+    file_dataset: FileDataset,
+    block_size: int,
+    type: str
+) -> int:
+    """
+    Estimate the total number of samples across the dataset files.
+    """
+    data_size = 0
+    files_count = len(file_dataset)
+
+    if type == 'sft':
+        from .dataset import SFTDataset
+        for idx in range(files_count):
+            dataset = SFTDataset(file_dataset[idx], block_size)
+            data_size += len(dataset)
+    elif type == 'dpo':
+        from .dataset import DPODataset
+        for idx in range(files_count):
+            dataset = DPODataset(file_dataset[idx], block_size)
+            data_size += len(dataset)
+    elif type in ('grpo', 'ppo'):
+        from .dataset import RLDataset
+        for idx in range(files_count):
+            dataset = RLDataset(file_dataset[idx])
+            data_size += len(dataset)
+    else:
+        from .dataset import PretrainDataset
+        for idx in range(files_count):
+            dataset = PretrainDataset(
+                file_dataset[idx],
+                block_size,
+                block_size
+            )
+            data_size += len(dataset)
+
+    return data_size
+
+
+def extract_policy_weights_from_ppo(model_config, ppo_weights):
+    from llm_model import LlmModel
+    from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel
+
+    policy_model = LlmModel(model_config)
+    value_model = ValueModel(LlmModel(model_config))
+
+    wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
+    wrapper.load_state_dict(ppo_weights)
+
+    return wrapper.policy_model.state_dict()
+
+
+def extract_value_weights_from_ppo(model_config, ppo_weights):
+    from llm_model import LlmModel
+    from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel
+
+    policy_model = LlmModel(model_config)
+    value_model = ValueModel(LlmModel(model_config))
+
+    wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
+    wrapper.load_state_dict(ppo_weights)
+
+    return wrapper.value_model.state_dict()
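
FileDataset above is just an abstract index of file paths, so a list-backed implementation is enough to drive estimate_data_size. A minimal sketch (the class name and file names are illustrative):

from typing import List
from llm_trainer.tools import FileDataset, estimate_data_size

class ListFileDataset(FileDataset):
    """A FileDataset backed by an in-memory list of file paths."""

    def __init__(self, files: List[str]):
        self._files = files

    def __len__(self) -> int:
        return len(self._files)

    def __getitem__(self, idx) -> str:
        return self._files[idx]

# Counts samples with the package's own dataset classes, so it needs real files:
# total = estimate_data_size(ListFileDataset(['sft_0.jsonl', 'sft_1.jsonl']),
#                            block_size=1024, type='sft')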