project-llm-trainer 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +707 -0
- llm_trainer/checkpoint.py +114 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +311 -0
- llm_trainer/ds_checkpoint.py +72 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +463 -0
- llm_trainer/grpo_trainer.py +410 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +266 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +686 -0
- llm_trainer/scheduler.py +220 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +327 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +630 -0
- project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.13.4.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.13.4.data/scripts/ds_train +17 -0
- project_llm_trainer-0.13.4.data/scripts/py_train +12 -0
- project_llm_trainer-0.13.4.data/scripts/smart_train +37 -0
- project_llm_trainer-0.13.4.data/scripts/vis_log +98 -0
- project_llm_trainer-0.13.4.data/scripts/vis_lr +46 -0
- project_llm_trainer-0.13.4.dist-info/METADATA +9 -0
- project_llm_trainer-0.13.4.dist-info/RECORD +32 -0
- project_llm_trainer-0.13.4.dist-info/WHEEL +5 -0
- project_llm_trainer-0.13.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
from typing import Tuple, List, Optional
|
|
2
|
+
import gc
|
|
3
|
+
import torch
|
|
4
|
+
from torch.utils.data import Dataset
|
|
5
|
+
from itertools import islice
|
|
6
|
+
|
|
7
|
+
from .base_trainer import BaseTrainer
|
|
8
|
+
from .train_configs import TrainConfig
|
|
9
|
+
from .dataset import DPODataset
|
|
10
|
+
from .loss import DPOLoss
|
|
11
|
+
from .tools import TrainerTools
|
|
12
|
+
from .utils import (
|
|
13
|
+
autocast,
|
|
14
|
+
get_dpo_collate_fn,
|
|
15
|
+
log_softmax,
|
|
16
|
+
disable_dropout_in_model
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from .checkpoint import (
|
|
20
|
+
save_checkpoint,
|
|
21
|
+
save_steps,
|
|
22
|
+
)
|
|
23
|
+
from .log import Logger
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DPOTrainer(BaseTrainer):
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
*,
|
|
30
|
+
train_config: TrainConfig,
|
|
31
|
+
eval_prompts: List[str]
|
|
32
|
+
):
|
|
33
|
+
self.dpo_config = train_config.dpo_config
|
|
34
|
+
super().__init__(
|
|
35
|
+
train_config=train_config,
|
|
36
|
+
eval_prompts=eval_prompts,
|
|
37
|
+
gradient_accumulation_steps=self.dpo_config.gradient_accumulation_steps
|
|
38
|
+
)
|
|
39
|
+
self.ref_model = self._init_ref_model()
|
|
40
|
+
|
|
41
|
+
def _init_ref_model(self):
|
|
42
|
+
ref_model = self._new_model(self.train_config)
|
|
43
|
+
|
|
44
|
+
if self.dpo_config.ref_model_checkpoint:
|
|
45
|
+
ref_model.load_state_dict(self.dpo_config.ref_model_checkpoint)
|
|
46
|
+
self.dpo_config.ref_model_checkpoint = {}
|
|
47
|
+
|
|
48
|
+
ref_model.eval()
|
|
49
|
+
for param in ref_model.parameters():
|
|
50
|
+
param.requires_grad = False
|
|
51
|
+
|
|
52
|
+
ref_model, _ = TrainerTools().parallel.process(
|
|
53
|
+
model=ref_model,
|
|
54
|
+
optimizer=None,
|
|
55
|
+
kwargs=self._init_ref_model_args(),
|
|
56
|
+
save_instance=False
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
return ref_model
|
|
60
|
+
|
|
61
|
+
def _new_model(self, train_config: TrainConfig):
|
|
62
|
+
model = super()._new_model(train_config)
|
|
63
|
+
disable_dropout_in_model(model)
|
|
64
|
+
return model
|
|
65
|
+
|
|
66
|
+
def _init_loss(self):
|
|
67
|
+
criterion = DPOLoss(
|
|
68
|
+
beta=self.dpo_config.loss_beta,
|
|
69
|
+
label_smoothing=self.dpo_config.loss_label_smoothing,
|
|
70
|
+
ipo=self.dpo_config.loss_ipo
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
return criterion, None
|
|
74
|
+
|
|
75
|
+
def _convert_train_args(self) -> Tuple[dict, dict, dict]:
|
|
76
|
+
dpo_collate_fn = get_dpo_collate_fn(self.dpo_config.mask_prompt)
|
|
77
|
+
parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
|
|
78
|
+
data_loader_kwargs.update({"collate_fn": dpo_collate_fn})
|
|
79
|
+
|
|
80
|
+
return parallel_kwargs, data_loader_kwargs, sampler_kwargs
|
|
81
|
+
|
|
82
|
+
def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
|
|
83
|
+
file_path = self.train_config.file_dataset[file_idx]
|
|
84
|
+
block_size = self.train_config.dataset_block_size
|
|
85
|
+
return DPODataset(file_path, block_size), file_path
|
|
86
|
+
|
|
87
|
+
def _calc_loss(self, inputs, attention_mask, logits, labels): ...
|
|
88
|
+
|
|
89
|
+
def _logprobs(self, logits, labels):
|
|
90
|
+
"""
|
|
91
|
+
Calculate the average log probabilities for a batch of sequences.
|
|
92
|
+
|
|
93
|
+
Args:
|
|
94
|
+
logits (torch.Tensor): Logits from the model with shape (B, T, V)
|
|
95
|
+
labels (torch.Tensor): Ground truth labels with shape (B, T).
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
torch.Tensor: Average log probabilities for each sequence in the batch.
|
|
99
|
+
Shape is (B,) representing the mean log probability for each sequence.
|
|
100
|
+
"""
|
|
101
|
+
loss_masks = (labels != -100)
|
|
102
|
+
|
|
103
|
+
logits = logits[:, :-1, :]
|
|
104
|
+
labels = labels[:, 1:].clone()
|
|
105
|
+
loss_masks = loss_masks[:, 1:]
|
|
106
|
+
|
|
107
|
+
# dummy token; we'll ignore the losses on these tokens later
|
|
108
|
+
labels[labels == -100] = 0
|
|
109
|
+
|
|
110
|
+
# Gather the log probabilities for the actual labels
|
|
111
|
+
per_token_logps = log_softmax(logits, labels)
|
|
112
|
+
|
|
113
|
+
# Apply the mask to set log-probs of padding tokens to 0
|
|
114
|
+
logprobs_sums = (per_token_logps * loss_masks).sum(-1)
|
|
115
|
+
logprobs_means = (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1).clamp(min=1.0)
|
|
116
|
+
|
|
117
|
+
return logprobs_sums, logprobs_means
|
|
118
|
+
|
|
119
|
+
def train(self):
|
|
120
|
+
# 梯度累积步数
|
|
121
|
+
gradient_accumulation_steps = max(1, self.gradient_accumulation_steps)
|
|
122
|
+
|
|
123
|
+
loss_accumulation = 0.0
|
|
124
|
+
aux_loss_accumulation = 0.0
|
|
125
|
+
nll_loss_accumulation = 0.0
|
|
126
|
+
batches_accumulated = 0
|
|
127
|
+
|
|
128
|
+
aux_loss_coef = self.train_config.loss_config.aux_loss_coef
|
|
129
|
+
nll_loss_coef = self.dpo_config.nll_loss_coef
|
|
130
|
+
|
|
131
|
+
for epoch in range(self.resume_epoch, self.train_config.n_epochs):
|
|
132
|
+
self.train_model.train()
|
|
133
|
+
file_count = len(self.train_config.file_dataset)
|
|
134
|
+
start_file_idx = self.resume_file_idx if epoch == self.resume_epoch else 0
|
|
135
|
+
|
|
136
|
+
for file_idx in range(start_file_idx, file_count):
|
|
137
|
+
dataset, file_path = self._create_dataset(file_idx)
|
|
138
|
+
train_data_loader = TrainerTools().parallel.process_dataloader(
|
|
139
|
+
dataset=dataset,
|
|
140
|
+
data_loader_kwargs=self.data_loader_kwargs,
|
|
141
|
+
sampler_kwargs=self.sampler_kwargs
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
last_ckpt_batch = 0
|
|
145
|
+
batch_count_per_file = len(train_data_loader)
|
|
146
|
+
|
|
147
|
+
TrainerTools().parallel.on_epoch_start(epoch)
|
|
148
|
+
self._on_file_start(epoch, file_path)
|
|
149
|
+
|
|
150
|
+
skip_batches = 0
|
|
151
|
+
if epoch == self.resume_epoch and file_idx == self.resume_file_idx:
|
|
152
|
+
skip_batches = self.resume_batch_idx
|
|
153
|
+
if skip_batches > 0 and TrainerTools().parallel.is_main_process:
|
|
154
|
+
Logger.std_log(f"Fast forwarding {skip_batches} batches in {file_path}...")
|
|
155
|
+
|
|
156
|
+
data_iterator = iter(train_data_loader)
|
|
157
|
+
if skip_batches > 0:
|
|
158
|
+
data_iterator = islice(data_iterator, skip_batches, None)
|
|
159
|
+
last_ckpt_batch = skip_batches
|
|
160
|
+
|
|
161
|
+
for batch, batch_data in enumerate(data_iterator):
|
|
162
|
+
batch = skip_batches + batch
|
|
163
|
+
|
|
164
|
+
# 是否需要更新梯度
|
|
165
|
+
if gradient_accumulation_steps > 1:
|
|
166
|
+
need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
|
|
167
|
+
else:
|
|
168
|
+
need_update_grad = True
|
|
169
|
+
|
|
170
|
+
try:
|
|
171
|
+
chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
|
|
172
|
+
chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
|
|
173
|
+
|
|
174
|
+
rejected_inputs: torch.Tensor = batch_data['rejected_inputs'].to(TrainerTools().parallel.device)
|
|
175
|
+
rejected_labels: torch.Tensor = batch_data['rejected_labels'].to(TrainerTools().parallel.device)
|
|
176
|
+
|
|
177
|
+
chosen_attention_masks: torch.Tensor = chosen_inputs != TrainerTools().tokenizer.pad
|
|
178
|
+
rejected_attention_masks: torch.Tensor = rejected_inputs != TrainerTools().tokenizer.pad
|
|
179
|
+
|
|
180
|
+
# 在batch维度concat
|
|
181
|
+
# [chosen, chosen, reject, reject]
|
|
182
|
+
concat_inputs = torch.concat([chosen_inputs, rejected_inputs], dim=0)
|
|
183
|
+
concat_labels = torch.concat([chosen_labels, rejected_labels], dim=0)
|
|
184
|
+
concat_attention_masks = torch.concat([chosen_attention_masks, rejected_attention_masks], dim=0)
|
|
185
|
+
|
|
186
|
+
if TrainerTools().parallel.parallel_train:
|
|
187
|
+
self.train_model.require_backward_grad_sync = need_update_grad
|
|
188
|
+
|
|
189
|
+
with autocast(TrainerTools().parallel.device_type):
|
|
190
|
+
policy_outputs = self.train_model(concat_inputs, attention_mask=concat_attention_masks)
|
|
191
|
+
policy_logprobs_sums, policy_logprobs_means = self._logprobs(policy_outputs['logits'], concat_labels)
|
|
192
|
+
|
|
193
|
+
with torch.no_grad():
|
|
194
|
+
ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_attention_masks)
|
|
195
|
+
ref_logprobs_sums, _ = self._logprobs(ref_outputs['logits'], concat_labels)
|
|
196
|
+
|
|
197
|
+
policy_chosen_logps = policy_logprobs_sums[:chosen_inputs.shape[0]]
|
|
198
|
+
policy_rejected_logps = policy_logprobs_sums[chosen_inputs.shape[0]:]
|
|
199
|
+
|
|
200
|
+
ref_chosen_logps = ref_logprobs_sums[:chosen_inputs.shape[0]]
|
|
201
|
+
ref_rejected_logps = ref_logprobs_sums[chosen_inputs.shape[0]:]
|
|
202
|
+
|
|
203
|
+
nll_loss = -policy_logprobs_means[:chosen_inputs.shape[0]].mean()
|
|
204
|
+
|
|
205
|
+
# calc loss
|
|
206
|
+
loss = self.criterion(
|
|
207
|
+
policy_chosen_logps,
|
|
208
|
+
policy_rejected_logps,
|
|
209
|
+
ref_chosen_logps,
|
|
210
|
+
ref_rejected_logps
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
if aux_loss_coef and policy_outputs.get('aux_loss'):
|
|
214
|
+
aux_loss = aux_loss_coef * policy_outputs.get('aux_loss')
|
|
215
|
+
else:
|
|
216
|
+
aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
|
|
217
|
+
|
|
218
|
+
if nll_loss_coef and nll_loss:
|
|
219
|
+
nll_loss = nll_loss_coef * nll_loss
|
|
220
|
+
else:
|
|
221
|
+
nll_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
|
|
222
|
+
|
|
223
|
+
if gradient_accumulation_steps > 1:
|
|
224
|
+
loss = loss / gradient_accumulation_steps
|
|
225
|
+
aux_loss = aux_loss / gradient_accumulation_steps
|
|
226
|
+
nll_loss = nll_loss / gradient_accumulation_steps
|
|
227
|
+
|
|
228
|
+
total_loss = loss + aux_loss + nll_loss
|
|
229
|
+
self._backward_loss(total_loss)
|
|
230
|
+
|
|
231
|
+
loss_accumulation += total_loss.detach().item()
|
|
232
|
+
aux_loss_accumulation += aux_loss.detach().item()
|
|
233
|
+
nll_loss_accumulation += nll_loss.detach().item()
|
|
234
|
+
|
|
235
|
+
batches_accumulated += 1
|
|
236
|
+
|
|
237
|
+
if need_update_grad:
|
|
238
|
+
self._apply_grad_clipping()
|
|
239
|
+
self._apply_step()
|
|
240
|
+
|
|
241
|
+
avg_loss, avg_aux_loss, avg_nll_loss = self._avg_loss(
|
|
242
|
+
losses=[
|
|
243
|
+
loss_accumulation,
|
|
244
|
+
aux_loss_accumulation,
|
|
245
|
+
nll_loss_accumulation,
|
|
246
|
+
],
|
|
247
|
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
|
248
|
+
batches_accumulated=batches_accumulated
|
|
249
|
+
)
|
|
250
|
+
|
|
251
|
+
self._log(
|
|
252
|
+
keys={
|
|
253
|
+
'epoch': epoch,
|
|
254
|
+
'file': f'{file_idx + 1}/{file_count}',
|
|
255
|
+
'batch': f'{batch + 1}/{batch_count_per_file}',
|
|
256
|
+
},
|
|
257
|
+
values={
|
|
258
|
+
'loss': avg_loss,
|
|
259
|
+
'moe_aux_loss': avg_aux_loss,
|
|
260
|
+
'nll_loss': avg_nll_loss
|
|
261
|
+
}
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
# reset to default
|
|
265
|
+
loss_accumulation = 0.0
|
|
266
|
+
aux_loss_accumulation = 0.0
|
|
267
|
+
nll_loss_accumulation = 0.0
|
|
268
|
+
batches_accumulated = 0
|
|
269
|
+
|
|
270
|
+
if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
|
|
271
|
+
save_checkpoint(model=self.train_model, optimizer=self.optimizer)
|
|
272
|
+
save_steps(
|
|
273
|
+
epoch=epoch,
|
|
274
|
+
file_idx=file_idx,
|
|
275
|
+
batch_idx=batch + 1,
|
|
276
|
+
lr_scheduler=self.lr_scheduler
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
last_ckpt_batch = batch
|
|
280
|
+
self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
|
|
281
|
+
except Exception as e:
|
|
282
|
+
self._on_exception(e, epoch, batch)
|
|
283
|
+
|
|
284
|
+
# 一个文件训练结束后,清理内存
|
|
285
|
+
del train_data_loader
|
|
286
|
+
del dataset
|
|
287
|
+
if hasattr(TrainerTools().parallel, '_sampler'):
|
|
288
|
+
TrainerTools().parallel._sampler = None
|
|
289
|
+
|
|
290
|
+
gc.collect()
|
|
291
|
+
torch.cuda.empty_cache()
|
|
292
|
+
|
|
293
|
+
# end epoch
|
|
294
|
+
|
|
295
|
+
# reset resume state
|
|
296
|
+
self.resume_file_idx = 0
|
|
297
|
+
self.resume_batch_idx = 0
|
|
298
|
+
|
|
299
|
+
save_checkpoint(model=self.train_model, optimizer=self.optimizer)
|
|
300
|
+
save_steps(
|
|
301
|
+
epoch=epoch + 1,
|
|
302
|
+
file_idx=0,
|
|
303
|
+
batch_idx=0,
|
|
304
|
+
lr_scheduler=self.lr_scheduler
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
TrainerTools().parallel.on_epoch_end(epoch)
|
|
308
|
+
self._on_epoch_end(tag=f'epoch:{epoch}')
|
|
309
|
+
|
|
310
|
+
TrainerTools().parallel.destroy()
|
|
311
|
+
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from glob import glob
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import shutil
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
from .tools import TrainerTools
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import deepspeed
|
|
11
|
+
from deepspeed import DeepSpeedEngine
|
|
12
|
+
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
|
13
|
+
except: ...
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
函数 功能 是否加载模型到内存 是否保存到文件 主要用途
|
|
17
|
+
get_fp32_state_dict_from_zero_checkpoint 从 ZeRO 检查点提取 FP32 状态字典 否 否 获取模型权重,用于推理、迁移等
|
|
18
|
+
load_state_dict_from_zero_checkpoint 从 ZeRO 检查点加载模型和优化器状态 是 否 恢复训练状态,继续训练
|
|
19
|
+
convert_zero_checkpoint_to_fp32_state_dict 将 ZeRO 检查点转换为独立的 FP32 状态字典文件 否 是 创建可移植的 FP32 权重文件,用于部署、分享等
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def save_ds_checkpoint(
|
|
23
|
+
model: nn.Module,
|
|
24
|
+
extra_module: Optional[nn.Module] = None
|
|
25
|
+
):
|
|
26
|
+
assert isinstance(model, DeepSpeedEngine)
|
|
27
|
+
ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
# 包括model、optimizer等状态
|
|
31
|
+
model.save_checkpoint(save_dir=ckpt_dir)
|
|
32
|
+
except: ...
|
|
33
|
+
|
|
34
|
+
# 只在main rank上执行
|
|
35
|
+
if TrainerTools().parallel.is_main_process:
|
|
36
|
+
if extra_module:
|
|
37
|
+
torch.save(extra_module.state_dict(), os.path.join(ckpt_dir, "extra_module_state_dict.pt"))
|
|
38
|
+
|
|
39
|
+
# 最多保存多少checkpoint,默认为2
|
|
40
|
+
max_to_keep = int(os.environ.get('CKPT_MAX_TO_KEEP', '2'))
|
|
41
|
+
# 删除历史checkpoint
|
|
42
|
+
ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
|
|
43
|
+
if len(ckpt_paths) > max_to_keep:
|
|
44
|
+
# 按修改时间排序,找到最旧的目录
|
|
45
|
+
oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
|
|
46
|
+
try:
|
|
47
|
+
shutil.rmtree(oldest_ckpt)
|
|
48
|
+
except: ...
|
|
49
|
+
|
|
50
|
+
TrainerTools().parallel.wait('remove old ds checkpoint')
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_ds_checkpoint(
|
|
54
|
+
model: nn.Module,
|
|
55
|
+
load_module_only: bool = False,
|
|
56
|
+
extra_module: Optional[nn.Module] = None
|
|
57
|
+
):
|
|
58
|
+
assert isinstance(model, DeepSpeedEngine)
|
|
59
|
+
ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
|
|
60
|
+
|
|
61
|
+
# 包括model、optimizer等状态
|
|
62
|
+
if os.path.exists(ckpt_dir):
|
|
63
|
+
model.load_checkpoint(
|
|
64
|
+
load_dir=ckpt_dir,
|
|
65
|
+
load_module_only=load_module_only
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
path = os.path.join(ckpt_dir, "extra_module_state_dict.pt")
|
|
69
|
+
if os.path.exists(path):
|
|
70
|
+
state = torch.load(path, map_location=TrainerTools().parallel.device, weights_only=True)
|
|
71
|
+
extra_module.load_state_dict(state)
|
|
72
|
+
|
llm_trainer/eval.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
from .generate_utils import generate
|
|
5
|
+
from .tools import TrainerTools
|
|
6
|
+
from .train_configs import TrainConfig
|
|
7
|
+
from .log import _get_log_dir
|
|
8
|
+
|
|
9
|
+
def submit_gen_task(
|
|
10
|
+
eval_model: torch.nn.Module,
|
|
11
|
+
train_config: TrainConfig,
|
|
12
|
+
tag,
|
|
13
|
+
prompt,
|
|
14
|
+
pixel_values,
|
|
15
|
+
tokens_per_image
|
|
16
|
+
):
|
|
17
|
+
tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True)
|
|
18
|
+
max_new_tokens = max(train_config.eval_config.max_seq_len - tokens.shape[1], 0)
|
|
19
|
+
|
|
20
|
+
gen_result = generate(
|
|
21
|
+
eval_model,
|
|
22
|
+
prompt=tokens,
|
|
23
|
+
max_new_tokens=max_new_tokens,
|
|
24
|
+
temperature=train_config.eval_config.temperature,
|
|
25
|
+
k=train_config.eval_config.top_k,
|
|
26
|
+
p=train_config.eval_config.top_p,
|
|
27
|
+
pixel_values=pixel_values,
|
|
28
|
+
tokens_per_image=tokens_per_image,
|
|
29
|
+
device=TrainerTools().parallel.device
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
with open(os.path.join(_get_log_dir(), 'gen.txt'), 'a') as f:
|
|
33
|
+
f.write(f"{tag}, gen->{gen_result}\n")
|