project_llm_trainer-0.13.4-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release: this version of project-llm-trainer has been flagged as possibly problematic.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +707 -0
- llm_trainer/checkpoint.py +114 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +311 -0
- llm_trainer/ds_checkpoint.py +72 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +463 -0
- llm_trainer/grpo_trainer.py +410 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +266 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +686 -0
- llm_trainer/scheduler.py +220 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +327 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +630 -0
- project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.13.4.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.13.4.data/scripts/ds_train +17 -0
- project_llm_trainer-0.13.4.data/scripts/py_train +12 -0
- project_llm_trainer-0.13.4.data/scripts/smart_train +37 -0
- project_llm_trainer-0.13.4.data/scripts/vis_log +98 -0
- project_llm_trainer-0.13.4.data/scripts/vis_lr +46 -0
- project_llm_trainer-0.13.4.dist-info/METADATA +9 -0
- project_llm_trainer-0.13.4.dist-info/RECORD +32 -0
- project_llm_trainer-0.13.4.dist-info/WHEEL +5 -0
- project_llm_trainer-0.13.4.dist-info/top_level.txt +1 -0
llm_trainer/scheduler.py
ADDED
@@ -0,0 +1,220 @@
from abc import ABC, abstractmethod
from typing import List, Optional
import math
import torch
from .log import Logger


class LRScheduler(ABC):
    @property
    @abstractmethod
    def cur_steps(self): ...

    @property
    @abstractmethod
    def cur_lr(self): ...

    @abstractmethod
    def step(self): ...

    @abstractmethod
    def can_clip_grad(self): ...

    @abstractmethod
    def get_ckpt_dict(self) -> dict: ...

    @abstractmethod
    def restore_ckpt_dict(self, ckpt: dict): ...


class WarmupCosineAnnealingLRScheduler(LRScheduler):
    def __init__(
        self,
        *,
        optimizer: torch.optim.Optimizer,
        warmup_iters: int,
        initial_lr: float,
        min_lr: float,
        max_lr: float,
        cosine_annealing_period: int,  # steps per annealing cycle
        cosine_annealing_period_mul: int = 0,  # cycle-length multiplier
        param_group_indices: Optional[List[int]] = None,
        need_log: bool = False
    ):
        super().__init__()

        self._optimizer = optimizer
        self._initial_lr = initial_lr
        self._min_lr = min_lr
        self._max_lr = max_lr
        self._warmup_iters = warmup_iters

        self._cosine_annealing_period = cosine_annealing_period
        self._cosine_annealing_period_mul = cosine_annealing_period_mul

        self.param_group_indices = param_group_indices

        self.T_cur = 0  # steps taken within the current cycle
        self.cycle = 0  # index of the current cycle

        if warmup_iters != 0:
            self._lr_increment = (max_lr - initial_lr) / warmup_iters
        else:
            self._lr_increment = 0

        self._steps = -1
        self._current_lr = initial_lr
        self._cosine_annealing_base_lr = None

        if need_log:
            self.logger = Logger('lr.txt')
        else:
            self.logger = None

    @property
    def cur_steps(self):
        return self._steps

    @property
    def cur_lr(self):
        return self._current_lr

    def step(self):
        self._steps += 1
        self._update_lr()

    def can_clip_grad(self):
        return self._steps > self._warmup_iters

    def _update_lr(self):
        if self.param_group_indices is None:
            target_groups = self._optimizer.param_groups
        else:
            target_groups = [self._optimizer.param_groups[i] for i in self.param_group_indices]

        # If period_mul is 0 there are no warm restarts: once past the total
        # cosine-annealing steps, hold the learning rate at min_lr.
        if self._cosine_annealing_period_mul == 0 and self._steps >= self._cosine_annealing_period + self._warmup_iters:
            lr = self._min_lr
            for param_group in target_groups:
                param_group['lr'] = lr
        elif self._steps <= self._warmup_iters:
            # Warmup: adjust learning rate linearly
            # (max_lr - initial_lr) / warmup_iters
            lr = self._initial_lr + self._steps * self._lr_increment
            for param_group in target_groups:
                param_group['lr'] = lr
        else:
            # Capture the lr reached at the end of warmup as the annealing peak.
            # Explicit None check, so a legitimate 0.0 lr is not re-captured.
            if self._cosine_annealing_base_lr is None:
                self._cosine_annealing_base_lr = self.cur_lr

            # Update the learning rate every step.
            # Maximum number of steps in the current cycle.
            T_max = self._cosine_annealing_period * (max(self._cosine_annealing_period_mul, 1) ** self.cycle)

            # Advance the cycle state.
            self.T_cur += 1
            calc_t = self.T_cur

            if self.T_cur >= T_max:
                if self._cosine_annealing_period_mul == 0:
                    self.T_cur = T_max
                    calc_t = T_max
                else:
                    self.cycle += 1
                    self.T_cur = 0
                    calc_t = T_max

            # Compute and apply the new learning rate.
            cos_factor = (1 + math.cos(math.pi * calc_t / T_max)) / 2
            lr = self._min_lr + (self._cosine_annealing_base_lr - self._min_lr) * cos_factor

            for param_group in target_groups:
                param_group['lr'] = lr

        self._current_lr = lr

        if self.logger:
            self.logger.log(f"step: {self.cur_steps}, lr: {lr}", log_to_console=False)

    def get_ckpt_dict(self) -> dict:
        return {
            'cur_lr': self._current_lr,
            'lr_steps': self.cur_steps,
            'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
            't_cur': self.T_cur,
            'cycle': self.cycle,
        }

    def restore_ckpt_dict(self, ckpt: dict):
        if 'cur_lr' in ckpt:
            self._current_lr = ckpt['cur_lr']

        if 'lr_steps' in ckpt:
            self._steps = ckpt['lr_steps']

        if 'cosine_annealing_base_lr' in ckpt:
            self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']

        if 't_cur' in ckpt:
            self.T_cur = ckpt['t_cur']

        if 'cycle' in ckpt:
            self.cycle = ckpt['cycle']

        self._update_lr()


class NoneLRScheduler(LRScheduler):
    def __init__(self, initial_lr):
        self._current_lr = initial_lr

    @property
    def cur_steps(self):
        return -1

    @property
    def cur_lr(self):
        return self._current_lr

    def step(self): ...

    def can_clip_grad(self):
        return True

    def get_ckpt_dict(self) -> dict:
        return {'cur_lr': self._current_lr}

    def restore_ckpt_dict(self, ckpt: dict):
        if 'cur_lr' in ckpt:
            self._current_lr = ckpt['cur_lr']


class CompositeLRScheduler(LRScheduler):
    def __init__(self, schedulers: List[LRScheduler]):
        self.schedulers = schedulers

    @property
    def cur_steps(self):
        return self.schedulers[0].cur_steps if self.schedulers else 0

    @property
    def cur_lr(self):
        return self.schedulers[0].cur_lr if self.schedulers else 0.0

    def step(self):
        for scheduler in self.schedulers:
            scheduler.step()

    def can_clip_grad(self):
        return all(s.can_clip_grad() for s in self.schedulers)

    def get_ckpt_dict(self) -> dict:
        ckpt = {}
        for i, scheduler in enumerate(self.schedulers):
            ckpt[f'scheduler_{i}'] = scheduler.get_ckpt_dict()
        return ckpt

    def restore_ckpt_dict(self, ckpt: dict):
        for i, scheduler in enumerate(self.schedulers):
            key = f'scheduler_{i}'
            if key in ckpt:
                scheduler.restore_ckpt_dict(ckpt[key])
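
A minimal usage sketch of the scheduler above (the hyperparameter values are illustrative, not package defaults). Warmup ramps the learning rate linearly from initial_lr to max_lr over warmup_iters calls to step(); cosine annealing then decays it from the post-warmup lr toward min_lr over cosine_annealing_period steps, and with cosine_annealing_period_mul == 0 it holds at min_lr afterwards:

import torch
from llm_trainer.scheduler import WarmupCosineAnnealingLRScheduler

# A throwaway parameter so the optimizer has a param group to drive.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-5)

scheduler = WarmupCosineAnnealingLRScheduler(
    optimizer=optimizer,
    warmup_iters=100,               # linear ramp: 1e-5 -> 5e-4 over 100 steps
    initial_lr=1e-5,
    min_lr=1e-5,
    max_lr=5e-4,
    cosine_annealing_period=1000,   # one decay cycle of 1000 steps
    cosine_annealing_period_mul=0,  # 0 = no warm restarts; hold min_lr afterwards
)

for _ in range(1200):               # one step() per optimizer update
    scheduler.step()

print(scheduler.cur_steps, scheduler.cur_lr)  # lr has settled at min_lr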
llm_trainer/sft_trainer.py
ADDED
@@ -0,0 +1,97 @@
from typing import Optional, Tuple, List
from torch.utils.data import Dataset

from llm_model import (
    VLMConfig,
    LlmModel,
    VlmModel
)

from .base_trainer import BaseTrainer
from .train_configs import TrainConfig
from .dataset import SFTDataset
from .utils import get_sft_collate_fn
from .tools import TrainerTools


class SFTTrainer(BaseTrainer):
    def __init__(
        self,
        *,
        train_config: TrainConfig,
        eval_prompts: List[str],
        eval_image_tags: Optional[List[str]] = None
    ):
        self.sft_config = train_config.sft_config
        self.pixel_values_provider = self.sft_config.pixel_values_provider
        self.eval_image_tags = eval_image_tags

        super().__init__(
            train_config=train_config,
            eval_prompts=eval_prompts,
            kd_config=self.sft_config.kd_config,
            gradient_accumulation_steps=self.sft_config.gradient_accumulation_steps
        )

        # Only VLM configs use a pixel-values provider; plain LLMs ignore it.
        if isinstance(train_config.model_config, VLMConfig):
            self.pixel_values_provider = self.sft_config.pixel_values_provider
        else:
            self.pixel_values_provider = None

    def _new_model(self, train_config: TrainConfig):
        if isinstance(train_config.model_config, VLMConfig):
            return VlmModel(train_config.model_config)
        else:
            return LlmModel(train_config.model_config)

    def _check_freeze_llm_model(self, model):
        # freeze llm model for vlm training
        if self.sft_config.freeze_llm_model:
            for name, param in model.named_parameters():
                if not any(sub_module in name for sub_module in ['multi_modal_projector']):
                    param.requires_grad = False

            # model.embed_tokens.eval()
            # model.layers.eval()
            # model.head_norm.eval()
            # model.lm_head.eval()

    def _convert_train_args(self) -> Tuple[dict, dict, dict]:
        sft_collate_fn = get_sft_collate_fn(self.sft_config.mask_prompt)
        parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
        data_loader_kwargs.update({"collate_fn": sft_collate_fn})

        return parallel_kwargs, data_loader_kwargs, sampler_kwargs

    def _get_pixel_values(self, batch_data):
        if self.pixel_values_provider and 'image_tags' in batch_data:
            image_tags = batch_data['image_tags']
            return self.pixel_values_provider(image_tags).to(TrainerTools().parallel.device)

        return None

    def _get_eval_pixel_values_and_tokens_count(self, eval_idx):
        if not self.eval_image_tags:
            return None, None

        eval_image_tag = self.eval_image_tags[eval_idx]
        if isinstance(self.train_config.model_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
            return self.pixel_values_provider([eval_image_tag]), self.train_config.model_config.tokens_per_image

        return None, None

    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
        file_path = self.train_config.file_dataset[file_idx]
        block_size = self.train_config.dataset_block_size

        image_tag_file_path = None
        tokens_per_image = -1

        if isinstance(self.train_config.model_config, VLMConfig):
            if self.sft_config.image_tags_file_dataset:
                image_tag_file_path = self.sft_config.image_tags_file_dataset[file_idx]

            if self.train_config.model_config.tokens_per_image:
                tokens_per_image = self.train_config.model_config.tokens_per_image

        return SFTDataset(file_path, block_size, image_tag_file_path, tokens_per_image), file_path
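
From _get_pixel_values and _get_eval_pixel_values_and_tokens_count above, the contract for sft_config.pixel_values_provider appears to be: a callable that maps a list of image tags to one batched tensor (the trainer moves it to the right device itself). A hypothetical sketch of a conforming provider; the image shape and preprocessing here are assumptions for illustration, not requirements of the package:

import torch

def dummy_pixel_values_provider(image_tags):
    # A real provider would look each tag up and load/preprocess the image;
    # this stub just returns a zero batch shaped like 224x224 RGB images.
    return torch.zeros(len(image_tags), 3, 224, 224)

# The trainer then does, in effect:
#   pixel_values = dummy_pixel_values_provider(batch_data['image_tags'])
#   pixel_values = pixel_values.to(TrainerTools().parallel.device)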
llm_trainer/tokenizer.py
ADDED
@@ -0,0 +1,162 @@
import os
import warnings
from typing import List, Dict, Union
from transformers import AutoTokenizer
import torch


class Tokenizer:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(os.environ['TOKEN_DIR'])

        self.text_end = '</s>'

        self.text_pad = '<pad>'
        self.text_unk = '<unk>'

        self.text_user = '<user>'
        self.text_assistant = '<assistant>'

        self.text_think_start = '<think>'
        self.text_think_end = '</think>'

        self.text_answer_start = '<answer>'
        self.text_answer_end = '</answer>'

        self.text_system = '<system>'

        self.text_image = '<image>'

        self.end = self.tokenizer.convert_tokens_to_ids(self.text_end)

        self.pad = self.tokenizer.convert_tokens_to_ids(self.text_pad)
        self.unk = self.tokenizer.convert_tokens_to_ids(self.text_unk)

        self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
        self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)

        self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
        self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)

        self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
        self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)

        self.system = self.tokenizer.convert_tokens_to_ids(self.text_system)
        self.image = self.tokenizer.convert_tokens_to_ids(self.text_image)

        self.vocab_size = len(self.tokenizer)

    def encode(
        self,
        text: str,
        unsqueeze: bool = False,
        covert_tensor: bool = False
    ) -> Union[torch.Tensor, List[int]]:
        # list: [x,x,x]
        encoded = self.tokenizer.encode(text, add_special_tokens=False)

        if unsqueeze:
            # tensor: [[x,x,x]]
            return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)
        else:
            # tensor: [x,x,x]
            if covert_tensor:
                return torch.tensor(encoded, dtype=torch.long)

            return encoded

    def batch_encode(
        self,
        text: List[str],
        padding=False,
        truncation=False,
        covert_tensor: bool = False,
        return_attention_mask: bool = False
    ) -> Union[torch.Tensor, List[List[int]]]:
        encoded = self.tokenizer(
            text,
            padding=padding,
            truncation=truncation,
            return_attention_mask=return_attention_mask
        )['input_ids']

        if covert_tensor:
            encoded = torch.tensor(encoded, dtype=torch.long)

        return encoded

    def decode(
        self,
        token: Union[torch.Tensor, List[int]],
        skip_special_tokens: bool = False
    ) -> str:
        return self.tokenizer.decode(token, skip_special_tokens=skip_special_tokens)

    def batch_decode(
        self,
        tokens: Union[torch.Tensor, List[int], List[List[int]]],
        skip_special_tokens: bool = False
    ) -> List[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=skip_special_tokens)

    def encode_to_token(self, text: str, unsqueeze=True, covert_tensor=True):
        warnings.warn('encode_to_token is deprecated. Please use `encode` instead.')
        return self.encode(text, unsqueeze, covert_tensor)

    def decode_to_text(self, token: torch.Tensor, skip_special_tokens: bool = False) -> str:
        warnings.warn('decode_to_text is deprecated. Please use `decode` instead.')
        return self.decode(token.squeeze(0), skip_special_tokens)

    def apply_chat_template(
        self,
        conversations: List[Dict[str, str]],
        tokenizer: bool = True,
        add_answer_tag_for_assistant: bool = True,
        unsqueeze=False,
        covert_tensor=False
    ):
        """
        [
            {"role":"system", "content":"system prompt"},
            {"role":"user", "content":"hello?"},
            {"role":"assistant", "content":"hello"},
            {"role":"user", "content":"hello hello?"},
            {"role":"assistant", "think":"thinking", "content":"hello hello"},
        ]
        becomes:
        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
        """

        chat_template = ''
        support_roles = {'system': self.text_system, 'user': self.text_user, 'assistant': self.text_assistant}
        for conversation in conversations:
            role = conversation['role']
            if role in support_roles:
                content = conversation['content']
                if add_answer_tag_for_assistant and role == 'assistant':
                    content = f"{self.text_answer_start}{content}{self.text_answer_end}"

                if 'think' in conversation:
                    content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"

                chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"

        if tokenizer:
            return self.encode(chat_template, unsqueeze, covert_tensor)

        return chat_template

    def get_special_tokens_dict(self):
        return {
            self.text_end: self.end,
            self.text_pad: self.pad,
            self.text_unk: self.unk,
            self.text_user: self.user,
            self.text_assistant: self.assistant,
            self.text_think_start: self.think_start,
            self.text_think_end: self.think_end,
            self.text_answer_start: self.answer_start,
            self.text_answer_end: self.answer_end,
            self.text_system: self.system,
            self.text_image: self.image,
        }
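
A short usage sketch of apply_chat_template, following the docstring above. It assumes the TOKEN_DIR environment variable points at a tokenizer directory (the path below is a placeholder) that actually defines the special tokens (<user>, <assistant>, </s>, and so on):

import os
os.environ.setdefault('TOKEN_DIR', '/path/to/tokenizer')  # assumption: a local tokenizer dir

from llm_trainer.tokenizer import Tokenizer

tok = Tokenizer()

conversations = [
    {"role": "system", "content": "system prompt"},
    {"role": "user", "content": "hello?"},
    {"role": "assistant", "think": "thinking", "content": "hello"},
]

# tokenizer=False returns the raw template string:
# <system>system prompt</s><user>hello?</s><assistant><think>thinking</think><answer>hello</answer></s>
text = tok.apply_chat_template(conversations, tokenizer=False)

# With tokenizer=True (the default), the template is encoded to ids;
# covert_tensor=True returns a torch.LongTensor instead of a Python list.
ids = tok.apply_chat_template(conversations, covert_tensor=True)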
llm_trainer/tools.py
ADDED
@@ -0,0 +1,116 @@
import os
from abc import ABC, abstractmethod
import torch
from .tokenizer import Tokenizer
from .parallel import DsParallel, DdpParallel, NoneParallel
from .log import Logger


parallel_types = {
    'ds': DsParallel,
    'ddp': DdpParallel,
    'none': NoneParallel
}

dtypes = {
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64
}


class TrainerTools:
    def __init__(self):
        # Initialize shared state only once; __new__ below makes this a singleton.
        if not hasattr(TrainerTools, "_first_init"):
            TrainerTools._first_init = True

            self.parallel = self._new_parallel()

            self.tokenizer = Tokenizer()
            self.use_amp = 'cuda' in self.parallel.device and not isinstance(self.parallel, DsParallel)

            Logger.std_log(f'world_size={self.parallel.world_size}, use_amp={self.use_amp}')

    def _new_parallel(self):
        parallel_type = os.environ.get('PARALLEL_TYPE', 'none')
        Logger.std_log(f'parallel_type={parallel_type}')
        return parallel_types[parallel_type]()

    def __new__(cls, *args, **kwargs):
        if not hasattr(TrainerTools, "_instance"):
            TrainerTools._instance = object.__new__(cls)

        return TrainerTools._instance


class FileDataset(ABC):
    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def __getitem__(self, idx) -> str: ...


def estimate_data_size(
    file_dataset: FileDataset,
    block_size: int,
    type: str
) -> int:
    """
    Estimate the total dataset size.
    """
    data_size = 0
    files_count = len(file_dataset)

    if type == 'sft':
        from .dataset import SFTDataset
        for idx in range(files_count):
            dataset = SFTDataset(file_dataset[idx], block_size)
            data_size += len(dataset)
    elif type == 'dpo':
        from .dataset import DPODataset
        for idx in range(files_count):
            dataset = DPODataset(file_dataset[idx], block_size)
            data_size += len(dataset)
    elif type == 'grpo' or type == 'ppo':
        from .dataset import RLDataset
        for idx in range(files_count):
            dataset = RLDataset(file_dataset[idx])
            data_size += len(dataset)
    else:
        from .dataset import PretrainDataset
        for idx in range(files_count):
            dataset = PretrainDataset(
                file_dataset[idx],
                block_size,
                block_size
            )
            data_size += len(dataset)

    return data_size


def extract_policy_weights_from_ppo(model_config, ppo_weights):
    from llm_model import LlmModel
    from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel

    policy_model = LlmModel(model_config)
    value_model = ValueModel(LlmModel(model_config))

    wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
    wrapper.load_state_dict(ppo_weights)

    return wrapper.policy_model.state_dict()


def extract_value_weights_from_ppo(model_config, ppo_weights):
    from llm_model import LlmModel
    from .ppo_trainer import PolicyAndValueModelWrapper, ValueModel

    policy_model = LlmModel(model_config)
    value_model = ValueModel(LlmModel(model_config))

    wrapper = PolicyAndValueModelWrapper(policy_model, value_model)
    wrapper.load_state_dict(ppo_weights)

    return wrapper.value_model.state_dict()
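
TrainerTools combines __new__ with a _first_init guard so it behaves as a process-wide singleton: the parallel context and tokenizer are constructed once and then shared by every call site. A minimal sketch of that behavior (the TOKEN_DIR path is a placeholder assumption, since Tokenizer reads it on construction):

import os

os.environ.setdefault('PARALLEL_TYPE', 'none')            # select NoneParallel
os.environ.setdefault('TOKEN_DIR', '/path/to/tokenizer')  # assumption: needed by Tokenizer

from llm_trainer.tools import TrainerTools

tools_a = TrainerTools()
tools_b = TrainerTools()
assert tools_a is tools_b          # same instance everywhere

print(tools_a.parallel.device, tools_a.use_amp)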