project-llm-trainer 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import Tuple, List, Optional
|
|
3
|
+
import torch
|
|
4
|
+
from torch.utils.data import Dataset
|
|
5
|
+
import torch.distributed as dist
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
|
|
8
|
+
from llm_model import LlmModel
|
|
9
|
+
|
|
10
|
+
from .parallel_ds import DsParallel
|
|
11
|
+
from .parallel_fsdp import FsdpParallel
|
|
12
|
+
from .trainer import Trainer
|
|
13
|
+
from .train_configs import TrainConfig
|
|
14
|
+
from .dataset import DPODataset
|
|
15
|
+
from .loss import DPOLoss
|
|
16
|
+
from .tools import TrainerTools
|
|
17
|
+
from .utils import get_dpo_collate_fn
|
|
18
|
+
|
|
19
|
+
from .checkpoint import (
|
|
20
|
+
save_checkpoint,
|
|
21
|
+
load_checkpoint_for_eval,
|
|
22
|
+
save_steps,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
class DPOTrainer(Trainer):
    """Direct Preference Optimization (DPO) trainer.

    Extends ``Trainer`` with a frozen reference model. Each training step
    scores the ``chosen`` and ``rejected`` sequences of a batch with both the
    policy (``self.train_model``) and the reference model, then feeds the
    per-sequence summed log-probabilities into ``DPOLoss``.
    """

    def __init__(
            self,
            *,
            train_config: TrainConfig,
            eval_prompts: List[str],
            eval_image_tags: Optional[List[int]] = None
    ):
        """Initialise the base trainer, then build the frozen reference model.

        Args:
            train_config: Training configuration; ``dpo_config`` provides the
                DPO loss hyper-parameters.
            eval_prompts: Prompts forwarded to the base ``Trainer``.
            eval_image_tags: Optional image tags forwarded to the base ``Trainer``.
        """
        super().__init__(
            train_config=train_config,
            eval_prompts=eval_prompts,
            eval_image_tags=eval_image_tags
        )

        self.reference_model = self._init_reference_model()

    def _init_reference_model(self):
        """Create, wrap and freeze the reference model.

        Weights come from ``train_config.init_state_dict`` when it is set (the
        config entry is cleared afterwards), otherwise from the latest
        checkpoint via ``load_checkpoint_for_eval``. The model is wrapped by a
        fresh parallel instance, put in eval mode, and every parameter gets
        ``requires_grad = False``.

        Returns:
            The wrapped reference model returned by ``parallel.process``.
        """
        parallel = TrainerTools().new_parallel()

        reference_model = LlmModel(self.train_config.model_config)
        if self.train_config.init_state_dict:
            reference_model.load_state_dict(self.train_config.init_state_dict, strict=False)
            # Presumably done so the config no longer keeps the state dict alive.
            self.train_config.init_state_dict = None
        else:
            load_checkpoint_for_eval(model=reference_model, device=parallel.device)

        reference_model, _ = parallel.process(
            model=reference_model,
            optimizer=None,
            kwargs=self._init_reference_args()
        )

        parallel.raw_model.eval()
        for param in parallel.raw_model.parameters():
            param.requires_grad = False

        return reference_model

    def _init_reference_args(self):
        """Build backend-specific kwargs for wrapping the reference model.

        Returns:
            * DeepSpeed (with ``ds_config``): micro-batch 1, no gradient
              accumulation, ZeRO stage 0 when a zero config exists, and
              fp16/bf16 settings copied from ``ds_config``.
            * FSDP (with ``fsdp_config``): wrap/offload settings copied from
              ``fsdp_config``.
            * ``None`` otherwise.
        """
        if isinstance(TrainerTools().parallel, DsParallel) and self.train_config.ds_config:
            ds_config = self.train_config.ds_config
            parallel_kwargs = {
                'gradient_accumulation_steps': 1,
                'train_micro_batch_size_per_gpu': 1
            }

            # The reference model has no optimizer (see _init_reference_model),
            # so ZeRO is turned off (stage 0) for it.
            if ds_config.zero_config:
                parallel_kwargs['zero_optimization'] = {'stage': 0}

            if ds_config.fp16_config:
                fp16_config = ds_config.fp16_config
                fp16 = {'enabled': fp16_config.enabled}

                if fp16_config.fp16_opt_level is not None:
                    fp16['fp16_opt_level'] = fp16_config.fp16_opt_level

                parallel_kwargs['fp16'] = fp16

            if ds_config.bf16_config:
                parallel_kwargs['bf16'] = {'enabled': ds_config.bf16_config.enabled}
        elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
            fsdp_config = self.train_config.fsdp_config
            parallel_kwargs = {
                'transformer_layer_cls': fsdp_config.transformer_layer_cls,
                'wrap_policy_num_params': fsdp_config.wrap_policy_num_params,
                'cpu_offload': fsdp_config.cpu_offload,
                'offload_params': fsdp_config.offload_params
            }
        else:
            parallel_kwargs = None

        return parallel_kwargs

    def _init_loss(self):
        """Create the DPO criterion from ``train_config.dpo_config``.

        Returns:
            ``(criterion, None)``.
        """
        criterion = DPOLoss(
            beta=self.train_config.dpo_config.loss_beta,
            label_smoothing=self.train_config.dpo_config.loss_label_smoothing,
            ipo=self.train_config.dpo_config.loss_ipo
        )

        return criterion, None

    def _convert_train_args(self) -> Tuple[dict, dict, dict, bool]:
        """Reuse the base train args, but collate batches with the DPO collate fn."""
        dpo_collate_fn = get_dpo_collate_fn(self.train_config.mask_prompt)
        parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = super()._convert_train_args()
        data_loader_kwargs.update({"collate_fn": dpo_collate_fn})

        return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim

    def _create_dataset(self, file_path) -> Dataset:
        """Build a ``DPODataset`` for ``file_path`` bounded by the model's max positions."""
        max_position_embeddings = self.train_config.model_config.max_position_embeddings
        return DPODataset(file_path, max_position_embeddings)

    # The DPO loss is computed inline in train(), so this base hook is a no-op here.
    def _calc_loss(self, inputs, attention_mask, logits, labels): ...

    def _log_probs_from_logits(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the log-probability of each label token.

        Args:
            logits: Shape (B, T, V).
            labels: Shape (B, T); integer token ids valid for ``gather``.

        Returns:
            Shape (B, T) tensor of ``log_softmax(logits)[..., label]``.
        """
        # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
        if logits.dtype in [torch.float32, torch.float64]:
            logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
            logsumexp_values = torch.stack(
                [torch.logsumexp(l, dim=-1) for l in logits]  # loop to reduce peak mem consumption
            )
            log_probs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
        else:
            log_probs_labels = []
            for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumption
                row_log_probs = F.log_softmax(row_logits, dim=-1)
                row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
                log_probs_labels.append(row_log_probs_labels)
            log_probs_labels = torch.stack(log_probs_labels)

        return log_probs_labels

    def _logprobs(self, logits, labels, mask):
        """Sum the log-probabilities of the target tokens of each sequence.

        Args:
            logits (torch.Tensor): Model logits with shape (B, T, V).
            labels (torch.Tensor): Token ids with shape (B, T), aligned with the
                inputs (the one-token shift happens here). ``-100`` marks
                positions that must not contribute (e.g. masked prompt tokens).
            mask (torch.Tensor): Shape (B, T); truthy for non-padding tokens.

        Returns:
            torch.Tensor: Shape (B,), the SUM (not the mean) of the
            log-probabilities of every kept target token per sequence.
        """
        # Position t's logits predict token t + 1.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

        # Keep a target only if it is a real (non-pad) token AND its label is
        # not -100. Previously only the pad mask was applied, so -100 positions
        # leaked log p(token 0) into the sum.
        valid_labels = labels != -100
        loss_mask = mask[:, 1:].bool() & valid_labels

        # Dummy id so gather stays in range; zeroed out by loss_mask below.
        labels[~valid_labels] = 0

        per_token_logps = self._log_probs_from_logits(logits, labels)

        return (per_token_logps * loss_mask).sum(-1)

    def train(self):
        """Run the DPO training loop over every epoch and every dataset file.

        Steps up to ``self.last_global_steps`` are skipped (used when resuming).
        Gradients are accumulated over ``gradient_accumulation_steps``
        micro-batches, and the last batch of a file always triggers an update.
        Step state is saved after every update. A checkpoint is saved every
        ``eval_batch_interval`` batches and at the end of each epoch.
        """
        # Number of micro-batches accumulated per optimizer step.
        gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
        global_steps = 0
        loss_accumulation = 0.0
        skipping_train = False

        aux_loss_coef = self.train_config.loss_config.aux_loss_coef

        for epoch in range(self.train_config.n_epochs):
            self.train_model.train()
            file_count = len(self.train_config.file_dataset)

            for file_idx in range(file_count):
                file_path = self.train_config.file_dataset[file_idx]

                dataset = self._create_dataset(file_path)
                train_data_loader = TrainerTools().parallel.process_dataloader(
                    dataset=dataset,
                    data_loader_kwargs=self.data_loader_kwargs,
                    sampler_kwargs=self.sampler_kwargs
                )

                last_ckpt_batch = 0
                batch_count_per_file = len(train_data_loader)

                TrainerTools().parallel.on_epoch_start(epoch)
                self._on_file_start(epoch, file_path)

                for batch, batch_data in enumerate(train_data_loader):
                    global_steps += 1
                    # Fast-forward past steps that were already trained (resume).
                    if global_steps < self.last_global_steps:
                        skipping_train = True
                        continue

                    skipping_train = False

                    # Whether this micro-batch ends an accumulation window
                    # (or is the last batch of the file) and must step the optimizer.
                    if gradient_accumulation_steps > 1:
                        need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
                    else:
                        need_update_grad = True

                    try:
                        device = TrainerTools().parallel.device
                        chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(device)
                        chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(device)
                        rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(device)
                        rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(device)

                        pad_id = TrainerTools().tokenizer.pad
                        chosen_attention_mask: torch.Tensor = chosen_inputs != pad_id
                        rejected_attention_mask: torch.Tensor = rejected_inputs != pad_id

                        # Concatenate along the batch dimension so one forward pass
                        # scores both: [chosen..., rejected...]
                        concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
                        concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
                        concat_mask = torch.concat([chosen_attention_mask, rejected_attention_mask], dim=0)

                        if TrainerTools().parallel.parallel_train:
                            # Only sync gradients on the micro-batch that steps.
                            self.train_model.require_backward_grad_sync = need_update_grad

                        with self.ctx:
                            policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
                            with torch.inference_mode():
                                ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)

                            policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
                            ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)

                            # calc loss
                            loss = self.criterion(policy_probs, ref_probs)

                            if aux_loss_coef and policy_outputs['aux_loss']:
                                loss += aux_loss_coef * policy_outputs['aux_loss']

                        # Backward runs outside the autocast region (AMP recommendation).
                        if gradient_accumulation_steps > 1:
                            loss = loss / gradient_accumulation_steps

                        loss_accumulation += loss.detach()
                        self._backward_loss(loss)

                        if need_update_grad:
                            # TODO: confirm this all_reduce is required here.
                            if TrainerTools().parallel.parallel_train:
                                dist.all_reduce(loss_accumulation, dist.ReduceOp.AVG)

                            # DeepSpeed already applies gradient clipping itself.
                            if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                                # clip grad
                                self.scalar.unscale_(self.optimizer)
                                torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)

                            self._step()

                            self._log_loss(
                                epoch_tag=f'epoch: {epoch}',
                                file_tag=f'file: {file_idx + 1}/{file_count}',
                                batch_tag=f'batch: {batch}/{batch_count_per_file}',
                                loss=loss_accumulation.item()
                            )
                            # reset to default
                            loss_accumulation = 0.0
                    except Exception as e:
                        self._on_exception(e, epoch, batch)
                    finally:
                        if need_update_grad:
                            save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)

                            if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
                                save_checkpoint(model=self.train_model, optimizer=self.optimizer)
                                last_ckpt_batch = batch
                                self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

                        # `loss` is unbound if the step failed before computing it.
                        try:
                            del loss
                        except UnboundLocalError: ...

            # end epoch
            if not skipping_train:
                save_checkpoint(model=self.train_model, optimizer=self.optimizer)
                save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
                TrainerTools().parallel.on_epoch_end(epoch)
                self._on_epoch_end(tag=f'epoch:{epoch}')

        # Give pending checkpoint saves time to finish.
        time.sleep(10)
        TrainerTools().parallel.destroy()
|
|
300
|
+
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from glob import glob
|
|
4
|
+
import shutil
|
|
5
|
+
from torch import nn
|
|
6
|
+
try:
|
|
7
|
+
from deepspeed import DeepSpeedEngine
|
|
8
|
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
|
9
|
+
except: ...
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
函数 功能 是否加载模型到内存 是否保存到文件 主要用途
|
|
13
|
+
get_fp32_state_dict_from_zero_checkpoint 从 ZeRO 检查点提取 FP32 状态字典 否 否 获取模型权重,用于推理、迁移等
|
|
14
|
+
load_state_dict_from_zero_checkpoint 从 ZeRO 检查点加载模型和优化器状态 是 否 恢复训练状态,继续训练
|
|
15
|
+
convert_zero_checkpoint_to_fp32_state_dict 将 ZeRO 检查点转换为独立的 FP32 状态字典文件 否 是 创建可移植的 FP32 权重文件,用于部署、分享等
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
def save_ds_checkpoint(
        model: nn.Module,
        suffix: Optional[str] = None
):
    """Save a DeepSpeed checkpoint, then prune the oldest ``global_*`` directory.

    The checkpoint directory is ``$DIST_CHECKPOINT_DIR`` (default
    ``'checkpoint'``), with ``_<suffix>`` appended when ``suffix`` is given.
    When more than two ``global_*`` sub-directories exist after saving, the
    one with the oldest modification time is deleted.

    Best-effort: if saving raises, the function returns without pruning, and
    filesystem errors during pruning are ignored.

    Args:
        model: A ``DeepSpeedEngine`` (asserted).
        suffix: Optional suffix for the checkpoint directory name.
    """
    assert isinstance(model, DeepSpeedEngine)
    ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
    if suffix:
        ckpt_dir = f"{ckpt_dir}_{suffix}"

    try:
        # Includes model, optimizer and other engine state.
        model.save_checkpoint(save_dir=ckpt_dir)
    except Exception:
        # Stay best-effort, but unlike a bare `except:` let
        # KeyboardInterrupt / SystemExit propagate.
        return

    # Delete old checkpoints. The whole step is guarded: if a directory vanishes
    # between glob() and getmtime()/rmtree() (e.g. removed concurrently), the
    # resulting OSError must not escape from a cleanup step.
    try:
        ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
        if len(ckpt_paths) > 2:
            # Oldest directory by modification time.
            oldest_ckpt = min(ckpt_paths, key=os.path.getmtime)
            shutil.rmtree(oldest_ckpt)
    except OSError:
        ...
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def load_ds_checkpoint(
        model: nn.Module,
        load_module_only: bool = False,
        suffix: Optional[str] = None
):
    """Restore a DeepSpeed engine from its checkpoint directory, if present.

    The directory is ``$DIST_CHECKPOINT_DIR`` (default ``'checkpoint'``),
    suffixed with ``_<suffix>`` when ``suffix`` is truthy. If it does not
    exist, nothing is loaded.

    Args:
        model: A ``DeepSpeedEngine`` (asserted).
        load_module_only: Forwarded to ``model.load_checkpoint``.
        suffix: Optional suffix for the checkpoint directory name.
    """
    assert isinstance(model, DeepSpeedEngine)

    base_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
    target_dir = f"{base_dir}_{suffix}" if suffix else base_dir

    # No checkpoint saved yet: keep the engine's current state.
    if not os.path.exists(target_dir):
        return

    # Includes model, optimizer and other engine state.
    model.load_checkpoint(load_dir=target_dir, load_module_only=load_module_only)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def load_ds_checkpoint_for_eval(model: nn.Module, suffix: Optional[str] = None):
    """Load consolidated FP32 weights from a DeepSpeed ZeRO checkpoint into ``model``.

    The directory is resolved exactly as in ``save_ds_checkpoint`` /
    ``load_ds_checkpoint``: ``$DIST_CHECKPOINT_DIR`` (default
    ``'checkpoint'``) with ``_<suffix>`` appended when ``suffix`` is given.
    This lets suffixed checkpoints be loaded for evaluation too. The default
    ``None`` keeps the previous behaviour.

    Args:
        model: A plain (non-engine) module; ``load_state_dict`` is strict.
        suffix: Optional suffix for the checkpoint directory name.
    """
    ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
    if suffix:
        ckpt_dir = f"{ckpt_dir}_{suffix}"

    state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
    model.load_state_dict(state_dict)
|
llm_trainer/eval.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .generate_utils import generate
|
|
6
|
+
from .checkpoint import load_checkpoint_for_eval
|
|
7
|
+
from .log import get_log_dir
|
|
8
|
+
from .tools import TrainerTools
|
|
9
|
+
from .train_configs import EvalConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _eval_task(
        eval_model: torch.nn.Module,
        eval_config: EvalConfig,
        tag,
        prompt,
        pixel_values,
        max_position_embeddings,
        tokens_per_image,
        device
):
    """Load the latest checkpoint into ``eval_model``, generate, and log the result.

    The generation for ``prompt`` is appended as ``"<tag>, gen-><result>"`` to
    ``f'{get_log_dir()}gen.txt'``. If the checkpoint cannot be loaded, the task
    is skipped silently (presumably so evaluation never interrupts training).

    Args:
        eval_model: Standalone model that receives the checkpoint weights.
        eval_config: Sampling settings (max_new_tokens, temperature, top_k, top_p).
        tag: Label written in front of the generated text.
        prompt: Prompt passed to ``generate``.
        pixel_values: Optional image inputs passed to ``generate``.
        max_position_embeddings: Context limit passed to ``generate``.
        tokens_per_image: Passed to ``generate``.
        device: Device used for checkpoint loading and generation.
    """
    log_dir = get_log_dir()

    # NOTE: if eval_model were not a standalone model but shared parameters
    # with an FSDP-wrapped training model, generation could instead run inside
    # ``FSDP.summon_full_params(module=eval_model, writeback=False, recurse=False)``.

    try:
        load_checkpoint_for_eval(eval_model, device=device)
    except Exception:
        return

    gen_result = generate(
        eval_model,
        prompt=prompt,
        max_position_embeddings=max_position_embeddings,
        max_new_tokens=eval_config.max_new_tokens,
        temperature=eval_config.temperature,
        k=eval_config.top_k,
        p=eval_config.top_p,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        device=device
    )

    # Generated text may be non-ASCII; write UTF-8 explicitly instead of
    # depending on the platform's locale encoding.
    with open(f'{log_dir}gen.txt', 'a', encoding='utf-8') as f:
        f.write(f"{tag}, gen->{gen_result}\n")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def submit_gen_task(
        eval_model: torch.nn.Module,
        eval_config: EvalConfig,
        tag,
        prompt,
        pixel_values,
        max_position_embeddings,
        tokens_per_image
):
    """Run one evaluation generation synchronously on the training device.

    ``eval_model`` is moved to ``TrainerTools().parallel.device`` for the task.
    It is always moved back to CPU afterwards, even if the task raises.
    Arguments are forwarded unchanged to ``_eval_task``.
    """
    # Wait 1s so the checkpoint can be found in DeepSpeed mode.
    time.sleep(1)

    device = TrainerTools().parallel.device
    eval_model.to(device)
    try:
        _eval_task(
            eval_model=eval_model,
            eval_config=eval_config,
            tag=tag,
            prompt=prompt,
            pixel_values=pixel_values,
            max_position_embeddings=max_position_embeddings,
            tokens_per_image=tokens_per_image,
            device=device
        )
    finally:
        # Free accelerator memory even if generation failed.
        eval_model.to('cpu')
|
|
85
|
+
|
|
86
|
+
# threading.Thread(target=_eval_task, args=args).start()
|