omnigenome 0.3.0a1__py3-none-any.whl → 1.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +26 -258
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/METADATA +9 -10
- omnigenome-1.0.0b0.dist-info/RECORD +6 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -12
- omnigenome/auto/auto_bench/auto_bench.py +0 -484
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -13
- omnigenome/auto/auto_train/auto_train.py +0 -430
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -12
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -13
- omnigenome/cli/commands/__init__.py +0 -13
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -13
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -13
- omnigenome/cli/commands/rna/rna_design.py +0 -178
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -12
- omnigenome/src/abc/__init__.py +0 -12
- omnigenome/src/abc/abstract_dataset.py +0 -622
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -689
- omnigenome/src/abc/abstract_tokenizer.py +0 -267
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -435
- omnigenome/src/lora/__init__.py +0 -13
- omnigenome/src/lora/lora_model.py +0 -294
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -499
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -12
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -12
- omnigenome/src/model/classification/model.py +0 -642
- omnigenome/src/model/embedding/__init__.py +0 -12
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -12
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -12
- omnigenome/src/model/regression/model.py +0 -786
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -12
- omnigenome/src/model/rna_design/model.py +0 -469
- omnigenome/src/model/seq2seq/__init__.py +0 -12
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -739
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -579
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -13
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -12
- omnigenome/utility/model_hub/model_hub.py +0 -231
- omnigenome/utility/pipeline_hub/__init__.py +0 -12
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- omnigenome-0.3.0a1.dist-info/entry_points.txt +0 -3
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/top_level.txt +0 -0
|
@@ -1,739 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: trainer.py
|
|
3
|
-
# time: 14:40 06/04/2024
|
|
4
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# github: https://github.com/yangheng95
|
|
6
|
-
# huggingface: https://huggingface.co/yangheng
|
|
7
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
-
|
|
10
|
-
import os
|
|
11
|
-
import time
|
|
12
|
-
import numpy as np
|
|
13
|
-
from torch.utils.data import DataLoader
|
|
14
|
-
from tqdm import tqdm
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
|
|
18
|
-
from ..misc.utils import env_meta_info, fprint, seed_everything
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def _infer_optimization_direction(metrics, prev_metrics):
|
|
22
|
-
"""
|
|
23
|
-
Infer the optimization direction based on metric values.
|
|
24
|
-
|
|
25
|
-
This function analyzes the trend of metric values to determine whether
|
|
26
|
-
larger values are better (e.g., accuracy) or smaller values are better
|
|
27
|
-
(e.g., loss).
|
|
28
|
-
|
|
29
|
-
Args:
|
|
30
|
-
metrics (dict): Current metric values
|
|
31
|
-
prev_metrics (list): Previous metric values
|
|
32
|
-
|
|
33
|
-
Returns:
|
|
34
|
-
str: Either 'larger_is_better' or 'smaller_is_better'
|
|
35
|
-
"""
|
|
36
|
-
larger_is_better_metrics = [
|
|
37
|
-
"accuracy",
|
|
38
|
-
"f1",
|
|
39
|
-
"recall",
|
|
40
|
-
"precision",
|
|
41
|
-
"roc_auc",
|
|
42
|
-
"pr_auc",
|
|
43
|
-
"score",
|
|
44
|
-
# ...
|
|
45
|
-
]
|
|
46
|
-
smaller_is_better_metrics = [
|
|
47
|
-
"loss",
|
|
48
|
-
"error",
|
|
49
|
-
"mse",
|
|
50
|
-
"mae",
|
|
51
|
-
"r2",
|
|
52
|
-
"distance",
|
|
53
|
-
# ...
|
|
54
|
-
]
|
|
55
|
-
for metric in larger_is_better_metrics:
|
|
56
|
-
if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
|
|
57
|
-
return "larger_is_better"
|
|
58
|
-
for metric in smaller_is_better_metrics:
|
|
59
|
-
if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
|
|
60
|
-
return "smaller_is_better"
|
|
61
|
-
|
|
62
|
-
fprint(
|
|
63
|
-
"Cannot determine the optimisation direction. Attempting inference from the metrics."
|
|
64
|
-
)
|
|
65
|
-
is_prev_increasing = np.mean(list(prev_metrics[0].values())[0]) < np.mean(
|
|
66
|
-
list(prev_metrics[-1].values())[0]
|
|
67
|
-
)
|
|
68
|
-
is_still_increasing = np.mean(list(prev_metrics[1].values())[0]) < np.mean(
|
|
69
|
-
list(metrics.values())[0]
|
|
70
|
-
)
|
|
71
|
-
fprint(
|
|
72
|
-
"Cannot determine the optimisation direction. Attempting inference from the metrics."
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
if is_prev_increasing and is_still_increasing:
|
|
76
|
-
return "larger_is_better"
|
|
77
|
-
|
|
78
|
-
is_prev_decreasing = np.mean(list(prev_metrics[0].values())[0]) > np.mean(
|
|
79
|
-
list(prev_metrics[-1].values())[0]
|
|
80
|
-
)
|
|
81
|
-
is_still_decreasing = np.mean(list(prev_metrics[1].values())[0]) > np.mean(
|
|
82
|
-
list(metrics.values())
|
|
83
|
-
)
|
|
84
|
-
|
|
85
|
-
if is_prev_decreasing and is_still_decreasing:
|
|
86
|
-
return "smaller_is_better"
|
|
87
|
-
|
|
88
|
-
return "larger_is_better" if is_prev_increasing else "smaller_is_better"
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
class AccelerateTrainer:
|
|
92
|
-
"""
|
|
93
|
-
A distributed training trainer using HuggingFace Accelerate.
|
|
94
|
-
|
|
95
|
-
This trainer provides distributed training capabilities with automatic mixed precision,
|
|
96
|
-
gradient accumulation, and early stopping. It supports both single and multi-GPU
|
|
97
|
-
training with seamless integration with HuggingFace Accelerate.
|
|
98
|
-
|
|
99
|
-
Attributes:
|
|
100
|
-
model: The model to train
|
|
101
|
-
train_loader: DataLoader for training data
|
|
102
|
-
eval_loader: DataLoader for validation data
|
|
103
|
-
test_loader: DataLoader for test data
|
|
104
|
-
epochs: Number of training epochs
|
|
105
|
-
patience: Early stopping patience
|
|
106
|
-
gradient_accumulation_steps: Number of steps for gradient accumulation
|
|
107
|
-
optimizer: The optimizer for training
|
|
108
|
-
loss_fn: Loss function
|
|
109
|
-
compute_metrics: List of metric functions to compute
|
|
110
|
-
accelerator: HuggingFace Accelerate instance
|
|
111
|
-
metrics: Dictionary to store training metrics
|
|
112
|
-
predictions: Dictionary to store predictions
|
|
113
|
-
|
|
114
|
-
Example:
|
|
115
|
-
>>> from omnigenome.src.trainer import AccelerateTrainer
|
|
116
|
-
>>> trainer = AccelerateTrainer(
|
|
117
|
-
... model=model,
|
|
118
|
-
... train_dataset=train_dataset,
|
|
119
|
-
... eval_dataset=eval_dataset,
|
|
120
|
-
... epochs=10,
|
|
121
|
-
... batch_size=32,
|
|
122
|
-
... optimizer=optimizer
|
|
123
|
-
... )
|
|
124
|
-
>>> metrics = trainer.train()
|
|
125
|
-
"""
|
|
126
|
-
|
|
127
|
-
def __init__(
|
|
128
|
-
self,
|
|
129
|
-
model,
|
|
130
|
-
train_dataset: torch.utils.data.Dataset = None,
|
|
131
|
-
eval_dataset: torch.utils.data.Dataset = None,
|
|
132
|
-
test_dataset: torch.utils.data.Dataset = None,
|
|
133
|
-
epochs: int = 3,
|
|
134
|
-
batch_size: int = 8,
|
|
135
|
-
patience: int = -1,
|
|
136
|
-
gradient_accumulation_steps: int = 1,
|
|
137
|
-
optimizer: torch.optim.Optimizer = None,
|
|
138
|
-
loss_fn: torch.nn.Module = None,
|
|
139
|
-
compute_metrics: list | str = None,
|
|
140
|
-
seed: int = 42,
|
|
141
|
-
autocast: str = "float16",
|
|
142
|
-
**kwargs,
|
|
143
|
-
):
|
|
144
|
-
"""
|
|
145
|
-
Initialize the AccelerateTrainer.
|
|
146
|
-
|
|
147
|
-
Args:
|
|
148
|
-
model: The model to train
|
|
149
|
-
train_dataset (torch.utils.data.Dataset, optional): Training dataset
|
|
150
|
-
eval_dataset (torch.utils.data.Dataset, optional): Validation dataset
|
|
151
|
-
test_dataset (torch.utils.data.Dataset, optional): Test dataset
|
|
152
|
-
epochs (int, optional): Number of training epochs. Defaults to 3
|
|
153
|
-
batch_size (int, optional): Batch size for training. Defaults to 8
|
|
154
|
-
patience (int, optional): Early stopping patience. Defaults to -1 (no early stopping)
|
|
155
|
-
gradient_accumulation_steps (int, optional): Number of steps for gradient accumulation. Defaults to 1
|
|
156
|
-
optimizer (torch.optim.Optimizer, optional): Optimizer for training
|
|
157
|
-
loss_fn (torch.nn.Module, optional): Loss function
|
|
158
|
-
compute_metrics (list | str, optional): List of metric functions or single metric function
|
|
159
|
-
seed (int, optional): Random seed for reproducibility. Defaults to 42
|
|
160
|
-
autocast (str, optional): Mixed precision type. Options: 'float16', 'bfloat16', 'no'. Defaults to 'float16'
|
|
161
|
-
**kwargs: Additional keyword arguments
|
|
162
|
-
"""
|
|
163
|
-
self.model = model
|
|
164
|
-
|
|
165
|
-
# DataLoaders
|
|
166
|
-
if kwargs.get("train_loader"):
|
|
167
|
-
self.train_loader = kwargs.get("train_loader")
|
|
168
|
-
self.eval_loader = kwargs.get("eval_loader", None)
|
|
169
|
-
self.test_loader = kwargs.get("test_loader", None)
|
|
170
|
-
else:
|
|
171
|
-
self.train_loader = (
|
|
172
|
-
DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
173
|
-
if train_dataset
|
|
174
|
-
else None
|
|
175
|
-
)
|
|
176
|
-
self.eval_loader = (
|
|
177
|
-
DataLoader(eval_dataset, batch_size=batch_size)
|
|
178
|
-
if eval_dataset
|
|
179
|
-
else None
|
|
180
|
-
)
|
|
181
|
-
self.test_loader = (
|
|
182
|
-
DataLoader(test_dataset, batch_size=batch_size)
|
|
183
|
-
if test_dataset
|
|
184
|
-
else None
|
|
185
|
-
)
|
|
186
|
-
self.train_loader = (
|
|
187
|
-
DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
188
|
-
if train_dataset
|
|
189
|
-
else None
|
|
190
|
-
)
|
|
191
|
-
self.eval_loader = (
|
|
192
|
-
DataLoader(eval_dataset, batch_size=batch_size)
|
|
193
|
-
if eval_dataset
|
|
194
|
-
else None
|
|
195
|
-
)
|
|
196
|
-
self.test_loader = (
|
|
197
|
-
DataLoader(test_dataset, batch_size=batch_size)
|
|
198
|
-
if test_dataset
|
|
199
|
-
else None
|
|
200
|
-
)
|
|
201
|
-
|
|
202
|
-
self.epochs = epochs
|
|
203
|
-
self.patience = patience
|
|
204
|
-
self.gradient_accumulation_steps = gradient_accumulation_steps
|
|
205
|
-
self.optimizer = optimizer
|
|
206
|
-
self.loss_fn = loss_fn
|
|
207
|
-
self.compute_metrics = (
|
|
208
|
-
compute_metrics if isinstance(compute_metrics, list) else [compute_metrics]
|
|
209
|
-
)
|
|
210
|
-
self.compute_metrics = (
|
|
211
|
-
compute_metrics if isinstance(compute_metrics, list) else [compute_metrics]
|
|
212
|
-
)
|
|
213
|
-
self.seed = seed
|
|
214
|
-
self._optimization_direction = None
|
|
215
|
-
self.trial_name = kwargs.get("trial_name", self.model.__class__.__name__)
|
|
216
|
-
|
|
217
|
-
# Determine mixed precision from `autocast` argument if desired
|
|
218
|
-
if autocast in ["float16", "fp16"]:
|
|
219
|
-
mp_setting = "fp16"
|
|
220
|
-
elif autocast in ["bfloat16", "bf16"]:
|
|
221
|
-
mp_setting = "bf16"
|
|
222
|
-
else:
|
|
223
|
-
mp_setting = "no"
|
|
224
|
-
|
|
225
|
-
# Prepare Accelerator
|
|
226
|
-
from accelerate import Accelerator
|
|
227
|
-
from accelerate import DistributedDataParallelKwargs
|
|
228
|
-
|
|
229
|
-
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
|
230
|
-
|
|
231
|
-
self.accelerator = Accelerator(
|
|
232
|
-
mixed_precision=mp_setting, kwargs_handlers=[ddp_kwargs]
|
|
233
|
-
)
|
|
234
|
-
|
|
235
|
-
self.accelerator = Accelerator(
|
|
236
|
-
mixed_precision=mp_setting, kwargs_handlers=[ddp_kwargs]
|
|
237
|
-
)
|
|
238
|
-
if self.loss_fn is not None:
|
|
239
|
-
self.model.set_loss_fn(self.loss_fn)
|
|
240
|
-
# 创建 dataloaders
|
|
241
|
-
if kwargs.get("train_loader"):
|
|
242
|
-
self.train_loader = kwargs.get("train_loader")
|
|
243
|
-
self.eval_loader = kwargs.get("eval_loader", None)
|
|
244
|
-
self.test_loader = kwargs.get("test_loader", None)
|
|
245
|
-
else:
|
|
246
|
-
self.train_loader = (
|
|
247
|
-
DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
248
|
-
if train_dataset
|
|
249
|
-
else None
|
|
250
|
-
)
|
|
251
|
-
self.eval_loader = (
|
|
252
|
-
DataLoader(eval_dataset, batch_size=batch_size)
|
|
253
|
-
if eval_dataset
|
|
254
|
-
else None
|
|
255
|
-
)
|
|
256
|
-
self.test_loader = (
|
|
257
|
-
DataLoader(test_dataset, batch_size=batch_size)
|
|
258
|
-
if test_dataset
|
|
259
|
-
else None
|
|
260
|
-
)
|
|
261
|
-
|
|
262
|
-
# 让 accelerate 处理模型和优化器的准备
|
|
263
|
-
to_prepare = [self.model]
|
|
264
|
-
if optimizer is not None:
|
|
265
|
-
to_prepare.append(optimizer)
|
|
266
|
-
if self.train_loader is not None:
|
|
267
|
-
to_prepare.append(self.train_loader)
|
|
268
|
-
if self.eval_loader is not None:
|
|
269
|
-
to_prepare.append(self.eval_loader)
|
|
270
|
-
if self.test_loader is not None:
|
|
271
|
-
to_prepare.append(self.test_loader)
|
|
272
|
-
|
|
273
|
-
prepared = self.accelerator.prepare(*to_prepare)
|
|
274
|
-
self.model = prepared[0]
|
|
275
|
-
idx = 1
|
|
276
|
-
if optimizer is not None:
|
|
277
|
-
self.optimizer = prepared[idx]
|
|
278
|
-
idx += 1
|
|
279
|
-
if self.train_loader is not None:
|
|
280
|
-
self.train_loader = prepared[idx]
|
|
281
|
-
idx += 1
|
|
282
|
-
if self.eval_loader is not None:
|
|
283
|
-
self.eval_loader = prepared[idx]
|
|
284
|
-
idx += 1
|
|
285
|
-
if self.test_loader is not None:
|
|
286
|
-
self.test_loader = prepared[idx]
|
|
287
|
-
|
|
288
|
-
self.metadata = env_meta_info()
|
|
289
|
-
self.metrics = {}
|
|
290
|
-
|
|
291
|
-
self.predictions = {}
|
|
292
|
-
|
|
293
|
-
def evaluate(self):
|
|
294
|
-
"""
|
|
295
|
-
Evaluate the model on the validation dataset.
|
|
296
|
-
|
|
297
|
-
This method runs the model in evaluation mode and computes metrics
|
|
298
|
-
on the validation dataset. It handles distributed evaluation and
|
|
299
|
-
gathers results from all processes.
|
|
300
|
-
|
|
301
|
-
Returns:
|
|
302
|
-
dict: Dictionary containing evaluation metrics
|
|
303
|
-
|
|
304
|
-
Example:
|
|
305
|
-
>>> metrics = trainer.evaluate()
|
|
306
|
-
>>> print(f"Validation accuracy: {metrics['accuracy']:.4f}")
|
|
307
|
-
"""
|
|
308
|
-
self.model.eval()
|
|
309
|
-
all_truth = []
|
|
310
|
-
all_preds = []
|
|
311
|
-
|
|
312
|
-
# 禁用进度条在非主进程上显示
|
|
313
|
-
it = tqdm(
|
|
314
|
-
self.eval_loader,
|
|
315
|
-
desc="Evaluating",
|
|
316
|
-
disable=not self.accelerator.is_main_process,
|
|
317
|
-
)
|
|
318
|
-
|
|
319
|
-
with torch.no_grad():
|
|
320
|
-
for batch in it:
|
|
321
|
-
output = self.accelerator.unwrap_model(self.model).predict(batch)
|
|
322
|
-
predictions = output["predictions"]
|
|
323
|
-
labels = batch["labels"]
|
|
324
|
-
|
|
325
|
-
# 收集所有进程的预测结果和标签
|
|
326
|
-
gathered_predictions = self.accelerator.gather(predictions)
|
|
327
|
-
gathered_labels = self.accelerator.gather(labels)
|
|
328
|
-
|
|
329
|
-
# 只在主进程中处理收集到的数据
|
|
330
|
-
if self.accelerator.is_main_process:
|
|
331
|
-
gathered_predictions = (
|
|
332
|
-
gathered_predictions.float().cpu().numpy(force=True)
|
|
333
|
-
)
|
|
334
|
-
gathered_labels = gathered_labels.float().cpu().numpy(force=True)
|
|
335
|
-
all_preds.append(gathered_predictions)
|
|
336
|
-
all_truth.append(gathered_labels)
|
|
337
|
-
|
|
338
|
-
# # 同步所有进程
|
|
339
|
-
# self.accelerator.wait_for_everyone()
|
|
340
|
-
|
|
341
|
-
# 只在主进程中计算指标
|
|
342
|
-
if self.accelerator.is_main_process:
|
|
343
|
-
all_preds = np.concatenate(all_preds, axis=0)
|
|
344
|
-
all_truth = np.concatenate(all_truth, axis=0)
|
|
345
|
-
|
|
346
|
-
if not np.all(all_truth == -100):
|
|
347
|
-
valid_metrics = {}
|
|
348
|
-
for metric_func in self.compute_metrics:
|
|
349
|
-
valid_metrics.update(metric_func(all_truth, all_preds))
|
|
350
|
-
else:
|
|
351
|
-
valid_metrics = {
|
|
352
|
-
"Validation labels predictions may be NaN. No metrics calculated.": 0
|
|
353
|
-
}
|
|
354
|
-
|
|
355
|
-
# 打印指标信息
|
|
356
|
-
fprint(valid_metrics)
|
|
357
|
-
else:
|
|
358
|
-
valid_metrics = None
|
|
359
|
-
|
|
360
|
-
self.predictions.update({"valid": {"pred": all_preds, "true": all_truth}})
|
|
361
|
-
|
|
362
|
-
return valid_metrics
|
|
363
|
-
|
|
364
|
-
def test(self):
|
|
365
|
-
"""
|
|
366
|
-
Test the model on the test dataset.
|
|
367
|
-
|
|
368
|
-
This method runs the model in evaluation mode and computes metrics
|
|
369
|
-
on the test dataset. It handles distributed testing and gathers
|
|
370
|
-
results from all processes.
|
|
371
|
-
|
|
372
|
-
Returns:
|
|
373
|
-
dict: Dictionary containing test metrics
|
|
374
|
-
|
|
375
|
-
Example:
|
|
376
|
-
>>> metrics = trainer.test()
|
|
377
|
-
>>> print(f"Test accuracy: {metrics['accuracy']:.4f}")
|
|
378
|
-
"""
|
|
379
|
-
self.model.eval()
|
|
380
|
-
all_truth = []
|
|
381
|
-
all_preds = []
|
|
382
|
-
|
|
383
|
-
it = tqdm(
|
|
384
|
-
self.test_loader,
|
|
385
|
-
desc="Testing",
|
|
386
|
-
disable=not self.accelerator.is_main_process,
|
|
387
|
-
)
|
|
388
|
-
|
|
389
|
-
with torch.no_grad():
|
|
390
|
-
for batch in it:
|
|
391
|
-
output = self.accelerator.unwrap_model(self.model).predict(batch)
|
|
392
|
-
predictions = output["predictions"]
|
|
393
|
-
labels = batch["labels"]
|
|
394
|
-
|
|
395
|
-
gathered_predictions = self.accelerator.gather(predictions)
|
|
396
|
-
gathered_labels = self.accelerator.gather(labels)
|
|
397
|
-
|
|
398
|
-
if self.accelerator.is_main_process:
|
|
399
|
-
gathered_predictions = (
|
|
400
|
-
gathered_predictions.float().cpu().numpy(force=True)
|
|
401
|
-
)
|
|
402
|
-
gathered_labels = gathered_labels.float().cpu().numpy(force=True)
|
|
403
|
-
all_preds.append(gathered_predictions)
|
|
404
|
-
all_truth.append(gathered_labels)
|
|
405
|
-
|
|
406
|
-
# # 同步所有进程
|
|
407
|
-
# self.accelerator.wait_for_everyone()
|
|
408
|
-
|
|
409
|
-
# 只在主进程中计算指标
|
|
410
|
-
if self.accelerator.is_main_process:
|
|
411
|
-
all_preds = np.concatenate(all_preds, axis=0)
|
|
412
|
-
all_truth = np.concatenate(all_truth, axis=0)
|
|
413
|
-
|
|
414
|
-
if not np.all(all_truth == -100):
|
|
415
|
-
test_metrics = {}
|
|
416
|
-
for metric_func in self.compute_metrics:
|
|
417
|
-
test_metrics.update(metric_func(all_truth, all_preds))
|
|
418
|
-
else:
|
|
419
|
-
test_metrics = {
|
|
420
|
-
"Test labels predictions may be NaN. No metrics calculated.": 0
|
|
421
|
-
}
|
|
422
|
-
# 打印指标信息
|
|
423
|
-
fprint(test_metrics)
|
|
424
|
-
else:
|
|
425
|
-
test_metrics = None
|
|
426
|
-
|
|
427
|
-
self.predictions.update({"test": {"pred": all_preds, "true": all_truth}})
|
|
428
|
-
|
|
429
|
-
return test_metrics
|
|
430
|
-
|
|
431
|
-
def train(self, path_to_save=None, **kwargs):
|
|
432
|
-
"""
|
|
433
|
-
Train the model using distributed training.
|
|
434
|
-
|
|
435
|
-
This method performs the complete training loop with validation,
|
|
436
|
-
early stopping, and model checkpointing. It handles distributed
|
|
437
|
-
training across multiple GPUs and processes.
|
|
438
|
-
|
|
439
|
-
Args:
|
|
440
|
-
path_to_save (str, optional): Path to save the trained model
|
|
441
|
-
**kwargs: Additional keyword arguments for model saving
|
|
442
|
-
|
|
443
|
-
Returns:
|
|
444
|
-
dict: Dictionary containing training metrics
|
|
445
|
-
|
|
446
|
-
Example:
|
|
447
|
-
>>> metrics = trainer.train(path_to_save="./checkpoints/model")
|
|
448
|
-
>>> print(f"Best validation accuracy: {metrics['best_valid']['accuracy']:.4f}")
|
|
449
|
-
"""
|
|
450
|
-
seed_everything(self.seed)
|
|
451
|
-
# 在所有进程上创建早停标志
|
|
452
|
-
early_stop_flag = torch.tensor(0, device=self.accelerator.device)
|
|
453
|
-
|
|
454
|
-
# 确保所有进程同步启动
|
|
455
|
-
self.accelerator.wait_for_everyone()
|
|
456
|
-
|
|
457
|
-
# Initial validation or test
|
|
458
|
-
if self.eval_loader is not None and len(self.eval_loader) > 0:
|
|
459
|
-
valid_metrics = self.evaluate()
|
|
460
|
-
else:
|
|
461
|
-
valid_metrics = self.test()
|
|
462
|
-
|
|
463
|
-
# 在主进程中更新指标和保存模型
|
|
464
|
-
if self.accelerator.is_main_process:
|
|
465
|
-
if self._is_metric_better(valid_metrics, stage="valid"):
|
|
466
|
-
self._save_state_dict()
|
|
467
|
-
early_stop_flag = torch.tensor(0, device=self.accelerator.device)
|
|
468
|
-
|
|
469
|
-
# 使用 all_gather 同步早停标志
|
|
470
|
-
gathered_flags = self.accelerator.gather(early_stop_flag)
|
|
471
|
-
early_stop_flag = (
|
|
472
|
-
gathered_flags if gathered_flags.ndim == 0 else gathered_flags[0]
|
|
473
|
-
) # 使用主进程的值
|
|
474
|
-
|
|
475
|
-
for epoch in range(self.epochs):
|
|
476
|
-
self.model.train()
|
|
477
|
-
|
|
478
|
-
train_it = tqdm(
|
|
479
|
-
self.train_loader,
|
|
480
|
-
desc=f"Epoch {epoch + 1}/{self.epochs} Loss",
|
|
481
|
-
disable=not self.accelerator.is_main_process,
|
|
482
|
-
)
|
|
483
|
-
# 使用 accelerator.accumulate 控制梯度累积
|
|
484
|
-
for step, batch in enumerate(train_it):
|
|
485
|
-
train_loss = []
|
|
486
|
-
|
|
487
|
-
with self.accelerator.accumulate(self.model):
|
|
488
|
-
outputs = self.model(**batch)
|
|
489
|
-
if "loss" not in outputs:
|
|
490
|
-
# Generally, the model should return a loss in the outputs via OmniGenBench
|
|
491
|
-
# For the Lora models, the loss is computed separately
|
|
492
|
-
if hasattr(self.model, "loss_function") and callable(self.model.loss_function):
|
|
493
|
-
loss = self.model.loss_function(outputs['logits'], outputs["labels"])
|
|
494
|
-
elif (hasattr(self.model, "model")
|
|
495
|
-
and hasattr(self.model.model, "loss_function")
|
|
496
|
-
and callable(self.model.model.loss_function)):
|
|
497
|
-
loss = self.model.model.loss_function(outputs['logits'], outputs["labels"])
|
|
498
|
-
else:
|
|
499
|
-
raise ValueError(
|
|
500
|
-
"The model does not have a loss function defined. "
|
|
501
|
-
"Please provide a loss function or ensure the model has one."
|
|
502
|
-
)
|
|
503
|
-
else:
|
|
504
|
-
# If the model returns a loss directly
|
|
505
|
-
loss = outputs["loss"]
|
|
506
|
-
|
|
507
|
-
train_loss.append(loss.item() * self.gradient_accumulation_steps)
|
|
508
|
-
train_it.set_description(
|
|
509
|
-
f"Epoch {epoch + 1}/{self.epochs} Loss: {np.nanmean(train_loss):.4f}"
|
|
510
|
-
)
|
|
511
|
-
|
|
512
|
-
self.accelerator.backward(loss)
|
|
513
|
-
|
|
514
|
-
self.optimizer.step()
|
|
515
|
-
self.optimizer.zero_grad()
|
|
516
|
-
|
|
517
|
-
# 同步所有进程后再进行评估
|
|
518
|
-
self.accelerator.wait_for_everyone()
|
|
519
|
-
|
|
520
|
-
if self.eval_loader is not None and len(self.eval_loader) > 0:
|
|
521
|
-
valid_metrics = self.evaluate()
|
|
522
|
-
else:
|
|
523
|
-
valid_metrics = self.test()
|
|
524
|
-
|
|
525
|
-
# 在主进程中更新指标和判断是否需要早停
|
|
526
|
-
if self.accelerator.is_main_process:
|
|
527
|
-
if self._is_metric_better(valid_metrics, stage="valid"):
|
|
528
|
-
self._save_state_dict()
|
|
529
|
-
early_stop_flag = torch.tensor(0, device=self.accelerator.device)
|
|
530
|
-
else:
|
|
531
|
-
early_stop_flag += 1
|
|
532
|
-
|
|
533
|
-
# 使用 all_gather 同步早停标志
|
|
534
|
-
gathered_flags = self.accelerator.gather(early_stop_flag)
|
|
535
|
-
early_stop_flag = (
|
|
536
|
-
gathered_flags if gathered_flags.ndim == 0 else gathered_flags[0]
|
|
537
|
-
) # 使用主进程的值
|
|
538
|
-
|
|
539
|
-
# 检查是否需要早停
|
|
540
|
-
if early_stop_flag.item() > self.patience:
|
|
541
|
-
if self.accelerator.is_main_process:
|
|
542
|
-
print(f"Early stopping at epoch {epoch + 1}.")
|
|
543
|
-
fprint(f"Early stopping at epoch {epoch + 1}.")
|
|
544
|
-
break
|
|
545
|
-
|
|
546
|
-
# 只在主进程中保存检查点
|
|
547
|
-
if path_to_save and self.accelerator.is_main_process:
|
|
548
|
-
_path_to_save = path_to_save + "_epoch_" + str(epoch + 1)
|
|
549
|
-
if valid_metrics:
|
|
550
|
-
for key, value in valid_metrics.items():
|
|
551
|
-
_path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
|
|
552
|
-
self.save_model(_path_to_save, **kwargs)
|
|
553
|
-
|
|
554
|
-
# 确保所有进程同步后再进入下一轮
|
|
555
|
-
self.accelerator.wait_for_everyone()
|
|
556
|
-
|
|
557
|
-
# Final test using the best checkpoint
|
|
558
|
-
if self.test_loader is not None and len(self.test_loader) > 0:
|
|
559
|
-
self._load_state_dict()
|
|
560
|
-
self.accelerator.wait_for_everyone() # 确保加载完成后再测试
|
|
561
|
-
test_metrics = self.test()
|
|
562
|
-
if self.accelerator.is_main_process:
|
|
563
|
-
self._is_metric_better(test_metrics, stage="test")
|
|
564
|
-
|
|
565
|
-
# 只在主进程中保存最终模型
|
|
566
|
-
if path_to_save and self.accelerator.is_main_process:
|
|
567
|
-
_path_to_save = path_to_save + "_final"
|
|
568
|
-
if self.metrics.get("test"):
|
|
569
|
-
for key, value in self.metrics["test"][-1].items():
|
|
570
|
-
_path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
|
|
571
|
-
self.save_model(_path_to_save, **kwargs)
|
|
572
|
-
|
|
573
|
-
self._remove_state_dict()
|
|
574
|
-
|
|
575
|
-
self.accelerator.free_memory()
|
|
576
|
-
del (
|
|
577
|
-
self.optimizer,
|
|
578
|
-
self.train_loader,
|
|
579
|
-
self.eval_loader,
|
|
580
|
-
self.test_loader,
|
|
581
|
-
)
|
|
582
|
-
|
|
583
|
-
return self.metrics
|
|
584
|
-
|
|
585
|
-
def _is_metric_better(self, metrics, stage="valid"):
|
|
586
|
-
"""
|
|
587
|
-
Check if the current metrics are better than the best metrics so far.
|
|
588
|
-
|
|
589
|
-
Args:
|
|
590
|
-
metrics (dict): Current metrics
|
|
591
|
-
stage (str): Stage of evaluation ('valid' or 'test')
|
|
592
|
-
|
|
593
|
-
Returns:
|
|
594
|
-
bool: True if current metrics are better, False otherwise
|
|
595
|
-
"""
|
|
596
|
-
# 只在主进程中进行metric比较
|
|
597
|
-
if not self.accelerator.is_main_process:
|
|
598
|
-
return False
|
|
599
|
-
|
|
600
|
-
assert stage in [
|
|
601
|
-
"valid",
|
|
602
|
-
"test",
|
|
603
|
-
], "The metrics stage should be either 'valid' or 'test'."
|
|
604
|
-
assert stage in [
|
|
605
|
-
"valid",
|
|
606
|
-
"test",
|
|
607
|
-
], "The metrics stage should be either 'valid' or 'test'."
|
|
608
|
-
|
|
609
|
-
prev_metrics = self.metrics.get(stage, None)
|
|
610
|
-
if stage not in self.metrics:
|
|
611
|
-
self.metrics.update({f"{stage}": [metrics]})
|
|
612
|
-
else:
|
|
613
|
-
self.metrics[f"{stage}"].append(metrics)
|
|
614
|
-
|
|
615
|
-
if "best_valid" not in self.metrics:
|
|
616
|
-
self.metrics.update({"best_valid": metrics})
|
|
617
|
-
return True
|
|
618
|
-
|
|
619
|
-
if prev_metrics is None:
|
|
620
|
-
return False
|
|
621
|
-
|
|
622
|
-
self._optimization_direction = (
|
|
623
|
-
_infer_optimization_direction(metrics, prev_metrics)
|
|
624
|
-
if self._optimization_direction is None
|
|
625
|
-
else self._optimization_direction
|
|
626
|
-
)
|
|
627
|
-
|
|
628
|
-
if self._optimization_direction == "larger_is_better":
|
|
629
|
-
if np.mean(list(metrics.values())[0]) > np.mean(
|
|
630
|
-
list(self.metrics["best_valid"].values())[0]
|
|
631
|
-
):
|
|
632
|
-
self.metrics.update({"best_valid": metrics})
|
|
633
|
-
return True
|
|
634
|
-
elif self._optimization_direction == "smaller_is_better":
|
|
635
|
-
if np.mean(list(metrics.values())[0]) < np.mean(
|
|
636
|
-
list(self.metrics["best_valid"].values())[0]
|
|
637
|
-
):
|
|
638
|
-
self.metrics.update({"best_valid": metrics})
|
|
639
|
-
return True
|
|
640
|
-
|
|
641
|
-
return False
|
|
642
|
-
|
|
643
|
-
def predict(self, data_loader):
|
|
644
|
-
"""
|
|
645
|
-
Make predictions using the trained model.
|
|
646
|
-
|
|
647
|
-
Args:
|
|
648
|
-
data_loader: DataLoader containing data to predict on
|
|
649
|
-
|
|
650
|
-
Returns:
|
|
651
|
-
dict: Dictionary containing predictions
|
|
652
|
-
"""
|
|
653
|
-
return self.accelerator.unwrap_model(self.model).predict(data_loader)
|
|
654
|
-
|
|
655
|
-
def get_model(self, **kwargs):
|
|
656
|
-
"""
|
|
657
|
-
Get the trained model.
|
|
658
|
-
|
|
659
|
-
Args:
|
|
660
|
-
**kwargs: Additional keyword arguments
|
|
661
|
-
|
|
662
|
-
Returns:
|
|
663
|
-
The trained model
|
|
664
|
-
"""
|
|
665
|
-
return self.model
|
|
666
|
-
|
|
667
|
-
def compute_metrics(self):
|
|
668
|
-
"""
|
|
669
|
-
Compute metrics for evaluation.
|
|
670
|
-
|
|
671
|
-
This method should be implemented by subclasses to provide specific
|
|
672
|
-
metric computation logic.
|
|
673
|
-
|
|
674
|
-
Raises:
|
|
675
|
-
NotImplementedError: If compute_metrics method is not implemented
|
|
676
|
-
"""
|
|
677
|
-
raise NotImplementedError(
|
|
678
|
-
"The compute_metrics() function should be implemented for your model."
|
|
679
|
-
" It should return a dictionary of metrics."
|
|
680
|
-
)
|
|
681
|
-
|
|
682
|
-
def save_model(self, path, overwrite=False, **kwargs):
|
|
683
|
-
"""
|
|
684
|
-
Save the trained model.
|
|
685
|
-
|
|
686
|
-
Args:
|
|
687
|
-
path (str): Path to save the model
|
|
688
|
-
overwrite (bool, optional): Whether to overwrite existing files. Defaults to False
|
|
689
|
-
**kwargs: Additional keyword arguments for model saving
|
|
690
|
-
"""
|
|
691
|
-
# Make certain only one process saves, if you're in distributed mode
|
|
692
|
-
if self.accelerator.is_main_process:
|
|
693
|
-
self.accelerator.unwrap_model(self.model).save(path, overwrite, **kwargs)
|
|
694
|
-
|
|
695
|
-
def _load_state_dict(self):
|
|
696
|
-
"""Load the best model state dictionary."""
|
|
697
|
-
if hasattr(self, "_model_state_dict_path") and os.path.exists(
|
|
698
|
-
self._model_state_dict_path
|
|
699
|
-
):
|
|
700
|
-
weights = torch.load(self._model_state_dict_path, map_location="cpu")
|
|
701
|
-
self.accelerator.unwrap_model(self.model).load_state_dict(weights)
|
|
702
|
-
|
|
703
|
-
def _save_state_dict(self):
|
|
704
|
-
"""Save the current model state dictionary."""
|
|
705
|
-
if not hasattr(self, "_model_state_dict_path"):
|
|
706
|
-
from hashlib import sha256
|
|
707
|
-
|
|
708
|
-
time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
|
|
709
|
-
hash_digest = sha256(self.__repr__().encode("utf-8")).hexdigest()
|
|
710
|
-
self._model_state_dict_path = f"tmp_ckpt_{time_str}_{hash_digest}.pt"
|
|
711
|
-
|
|
712
|
-
if os.path.exists(self._model_state_dict_path):
|
|
713
|
-
os.remove(self._model_state_dict_path)
|
|
714
|
-
|
|
715
|
-
# Use accelerator to gather model weights on one process
|
|
716
|
-
if self.accelerator.is_main_process:
|
|
717
|
-
torch.save(
|
|
718
|
-
self.accelerator.unwrap_model(self.model).state_dict(),
|
|
719
|
-
self._model_state_dict_path,
|
|
720
|
-
)
|
|
721
|
-
torch.save(
|
|
722
|
-
self.accelerator.unwrap_model(self.model).state_dict(),
|
|
723
|
-
self._model_state_dict_path,
|
|
724
|
-
)
|
|
725
|
-
|
|
726
|
-
def _remove_state_dict(self):
|
|
727
|
-
"""Remove the temporary model state dictionary file."""
|
|
728
|
-
if not hasattr(self, "_model_state_dict_path"):
|
|
729
|
-
from hashlib import sha256
|
|
730
|
-
|
|
731
|
-
time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
|
|
732
|
-
hash_digest = sha256(self.__repr__().encode("utf-8")).hexdigest()
|
|
733
|
-
self._model_state_dict_path = f"tmp_ckpt_{time_str}_{hash_digest}.pt"
|
|
734
|
-
|
|
735
|
-
if (
|
|
736
|
-
os.path.exists(self._model_state_dict_path)
|
|
737
|
-
and self.accelerator.is_main_process
|
|
738
|
-
):
|
|
739
|
-
os.remove(self._model_state_dict_path)
|