omnigenome-0.3.0a0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of omnigenome might be problematic.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
omnigenome/src/trainer/hf_trainer.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# file: hf_trainer.py
+# time: 14:40 06/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+"""
+HuggingFace trainer integration for OmniGenome models.
+
+This module provides HuggingFace trainer wrappers for OmniGenome models,
+enabling seamless integration with the HuggingFace training ecosystem
+while maintaining OmniGenome-specific functionality.
+"""
+
+from transformers import Trainer
+from transformers import TrainingArguments
+
+from ... import __name__ as omnigenome_name
+from ... import __version__ as omnigenome_version
+
+
+class HFTrainer(Trainer):
+    """
+    HuggingFace trainer wrapper for OmniGenome models.
+
+    This class extends the HuggingFace Trainer to include OmniGenome-specific
+    metadata and functionality while maintaining full compatibility with the
+    HuggingFace training ecosystem.
+
+    Attributes:
+        metadata: Dictionary containing OmniGenome library information
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the HuggingFace trainer wrapper.
+
+        Args:
+            *args: Positional arguments passed to the parent Trainer
+            **kwargs: Keyword arguments passed to the parent Trainer
+        """
+        super(HFTrainer, self).__init__(*args, **kwargs)
+        self.metadata = {
+            "library_name": omnigenome_name,
+            "omnigenome_version": omnigenome_version,
+        }
+
+
+class HFTrainingArguments(TrainingArguments):
+    """
+    HuggingFace training arguments wrapper for OmniGenome models.
+
+    This class extends the HuggingFace TrainingArguments to include
+    OmniGenome-specific metadata while maintaining full compatibility
+    with the HuggingFace training ecosystem.
+
+    Attributes:
+        metadata: Dictionary containing OmniGenome library information
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the HuggingFace training arguments wrapper.
+
+        Args:
+            *args: Positional arguments passed to the parent TrainingArguments
+            **kwargs: Keyword arguments passed to the parent TrainingArguments
+        """
+        super(HFTrainingArguments, self).__init__(*args, **kwargs)
+        self.metadata = {
+            "library_name": omnigenome_name,
+            "omnigenome_version": omnigenome_version,
+        }
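For orientation, here is a minimal usage sketch of the two wrappers above. It is not taken from the package documentation: the backbone checkpoint, toy dataset, and output directory are placeholders standing in for a genomic foundation model and tokenized sequences, and the import path simply mirrors the module location shown in this diff. Since the classes only attach OmniGenome metadata, the call pattern is the standard transformers Trainer/TrainingArguments one.

# Hypothetical usage sketch (placeholder model/data, not from the package docs).
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from omnigenome.src.trainer.hf_trainer import HFTrainer, HFTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")  # placeholder backbone
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=2)

# Toy records standing in for tokenized RNA/DNA sequences.
raw = Dataset.from_dict({"text": ["ACGUACGU", "UGCAUGCA"] * 8, "label": [0, 1] * 8})
ds = raw.map(lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=16))

args = HFTrainingArguments(
    output_dir="./hf_trainer_demo",   # placeholder output directory
    num_train_epochs=1,
    per_device_train_batch_size=4,
    report_to=[],                     # keep the demo free of external loggers
)

trainer = HFTrainer(model=model, args=args, train_dataset=ds, eval_dataset=ds)
trainer.train()
print(trainer.metadata)               # metadata dict attached by the wrapper above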
omnigenome/src/trainer/trainer.py
@@ -0,0 +1,579 @@
+# -*- coding: utf-8 -*-
+# file: trainer.py
+# time: 14:40 06/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+"""
+Training utilities for OmniGenome models.
+
+This module provides a comprehensive training framework for OmniGenome models,
+including automatic mixed precision training, early stopping, metric tracking,
+and model checkpointing.
+"""
+import os
+import tempfile
+import autocuda
+import numpy as np
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from ..misc.utils import env_meta_info, fprint, seed_everything
+
+import torch
+from torch.cuda.amp import GradScaler
+
+
+def _infer_optimization_direction(metrics, prev_metrics):
+    """
+    Infer the optimization direction based on metric names and trends.
+
+    This function determines whether larger or smaller values are better for
+    the given metrics by analyzing metric names and their trends over time.
+
+    Args:
+        metrics (dict): Current metric values
+        prev_metrics (list): Previous metric values from multiple epochs
+
+    Returns:
+        str: Either "larger_is_better" or "smaller_is_better"
+    """
+    larger_is_better_metrics = [
+        "accuracy",
+        "f1",
+        "recall",
+        "precision",
+        "roc_auc",
+        "pr_auc",
+        "score",
+        # ...
+    ]
+    smaller_is_better_metrics = [
+        "loss",
+        "error",
+        "mse",
+        "mae",
+        "r2",
+        "distance",
+        # ...
+    ]
+    for metric in larger_is_better_metrics:
+        if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
+            return "larger_is_better"
+    for metric in smaller_is_better_metrics:
+        if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
+            return "smaller_is_better"
+
+    fprint(
+        "Cannot determine the optimisation direction. Attempting inference from the metrics."
+    )
+    is_prev_increasing = np.mean(list(prev_metrics[0].values())[0]) < np.mean(
+        list(prev_metrics[-1].values())[0]
+    )
+    is_still_increasing = np.mean(list(prev_metrics[1].values())[0]) < np.mean(
+        list(metrics.values())[0]
+    )
+    fprint(
+        "Cannot determine the optimisation direction. Attempting inference from the metrics."
+    )
+
+    if is_prev_increasing and is_still_increasing:
+        return "larger_is_better"
+
+    is_prev_decreasing = np.mean(list(prev_metrics[0].values())[0]) > np.mean(
+        list(prev_metrics[-1].values())[0]
+    )
+    is_still_decreasing = np.mean(list(prev_metrics[1].values())[0]) > np.mean(
+        list(metrics.values())[0]
+    )
+
+    if is_prev_decreasing and is_still_decreasing:
+        return "smaller_is_better"
+
+    return "larger_is_better" if is_prev_increasing else "smaller_is_better"
+
+
+class Trainer:
+    """
+    Comprehensive trainer for OmniGenome models.
+
+    This trainer provides a complete training framework with automatic mixed precision,
+    early stopping, metric tracking, and model checkpointing. It supports various
+    training configurations and can handle different types of genomic sequence tasks.
+
+    Attributes:
+        model: The model to be trained
+        train_loader: DataLoader for training data
+        eval_loader: DataLoader for validation data
+        test_loader: DataLoader for test data
+        epochs: Number of training epochs
+        patience: Early stopping patience
+        optimizer: Optimizer for training
+        loss_fn: Loss function
+        compute_metrics: List of metric computation functions
+        device: Device to run training on
+        scaler: Gradient scaler for mixed precision training
+        metrics: Dictionary to store training metrics
+        predictions: Dictionary to store model predictions
+    """
+
+    def __init__(
+        self,
+        model,
+        train_dataset: torch.utils.data.Dataset = None,
+        eval_dataset: torch.utils.data.Dataset = None,
+        test_dataset: torch.utils.data.Dataset = None,
+        epochs: int = 3,
+        batch_size: int = 8,
+        patience: int = -1,
+        gradient_accumulation_steps: int = 1,
+        optimizer: torch.optim.Optimizer = None,
+        loss_fn: torch.nn.Module = None,
+        compute_metrics: list | str = None,
+        seed: int = 42,
+        device: [torch.device | str] = None,
+        autocast: str = "float16",
+        **kwargs,
+    ):
+        """
+        Initialize the trainer.
+
+        Args:
+            model: The model to be trained
+            train_dataset: Training dataset
+            eval_dataset: Validation dataset
+            test_dataset: Test dataset
+            epochs (int): Number of training epochs (default: 3)
+            batch_size (int): Batch size for training (default: 8)
+            patience (int): Early stopping patience (default: -1, no early stopping)
+            gradient_accumulation_steps (int): Gradient accumulation steps (default: 1)
+            optimizer: Optimizer for training (default: None)
+            loss_fn: Loss function (default: None)
+            compute_metrics: Metric computation functions (default: None)
+            seed (int): Random seed (default: 42)
+            device: Device to run training on (default: None, auto-detect)
+            autocast (str): Mixed precision type (default: "float16")
+            **kwargs: Additional keyword arguments
+        """
+        # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
+
+        self.model = model
+
+        # DataLoaders
+        if kwargs.get("train_loader"):
+            self.train_loader = kwargs.get("train_loader", None)
+            self.eval_loader = kwargs.get("eval_loader", None)
+            self.test_loader = kwargs.get("test_loader", None)
+        else:
+            self.train_loader = DataLoader(
+                train_dataset, batch_size=batch_size, shuffle=True
+            )
+            self.eval_loader = (
+                DataLoader(eval_dataset, batch_size=batch_size)
+                if eval_dataset
+                else None
+            )
+            self.test_loader = (
+                DataLoader(test_dataset, batch_size=batch_size)
+                if test_dataset
+                else None
+            )
+
+        self.epochs = epochs
+        self.patience = patience if patience > 0 else epochs
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.optimizer = optimizer
+        self.loss_fn = loss_fn
+        self.compute_metrics = (
+            compute_metrics if isinstance(compute_metrics, list) else [compute_metrics]
+        )
+        self.seed = seed
+        self.device = device if device else autocuda.auto_cuda()
+        self.device = torch.device(self.device) if isinstance(self.device, str) else self.device
+
+        self.fast_dtype = {
+            "float32": torch.float32,
+            "fp32": torch.float32,
+            "float16": torch.float16,
+            "fp16": torch.float16,
+            "bfloat16": torch.bfloat16,
+            "bf16": torch.bfloat16,
+        }.get(autocast, torch.float16)
+        self.scaler = GradScaler()
+        if self.loss_fn is not None:
+            self.model.set_loss_fn(self.loss_fn)
+
+        self.model.to(self.device)
+
+        self.metadata = env_meta_info()
+        self.metrics = {}
+
+        self._optimization_direction = None
+        self.trial_name = kwargs.get("trial_name", self.model.__class__.__name__)
+
+        self.predictions = {}
+
+    def _is_metric_better(self, metrics, stage="valid"):
+        """
+        Check if the current metrics are better than the best metrics so far.
+
+        Args:
+            metrics (dict): Current metric values
+            stage (str): Stage name ("valid" or "test")
+
+        Returns:
+            bool: True if current metrics are better than best metrics
+        """
+        assert stage in [
+            "valid",
+            "test",
+        ], "The metrics stage should be either 'valid' or 'test'."
+
+        prev_metrics = self.metrics.get(stage, None)
+        if stage not in self.metrics:
+            self.metrics.update({f"{stage}": [metrics]})
+        else:
+            self.metrics[f"{stage}"].append(metrics)
+
+        if "best_valid" not in self.metrics:
+            self.metrics.update({"best_valid": metrics})
+            return True
+
+        if prev_metrics is None:
+            return False
+
+        self._optimization_direction = (
+            _infer_optimization_direction(metrics, prev_metrics)
+            if self._optimization_direction is None
+            else self._optimization_direction
+        )
+
+        if self._optimization_direction == "larger_is_better":
+            if np.mean(list(metrics.values())[0]) > np.mean(
+                list(self.metrics["best_valid"].values())[0]
+            ):
+                self.metrics.update({"best_valid": metrics})
+                return True
+        elif self._optimization_direction == "smaller_is_better":
+            if np.mean(list(metrics.values())[0]) < np.mean(
+                list(self.metrics["best_valid"].values())[0]
+            ):
+                self.metrics.update({"best_valid": metrics})
+                return True
+
+        return False
+
+    def train(self, path_to_save=None, **kwargs):
+        """
+        Train the model.
+
+        Args:
+            path_to_save (str, optional): Path to save the best model
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            dict: Training metrics and results
+        """
+        seed_everything(self.seed)
+        patience = 0
+
+        if self.eval_loader is not None and len(self.eval_loader) > 0:
+            valid_metrics = self.evaluate()
+        else:
+            valid_metrics = self.test()
+        if self._is_metric_better(valid_metrics, stage="valid"):
+            self._save_state_dict()
+            patience = 0
+
+        for epoch in range(self.epochs):
+            self.model.train()
+            train_loss = []
+            train_it = tqdm(
+                self.train_loader, desc=f"Epoch {epoch + 1}/{self.epochs} Loss"
+            )
+            for step, batch in enumerate(train_it):
+                batch = batch.to(self.device)
+
+                if step % self.gradient_accumulation_steps == 0:
+                    self.optimizer.zero_grad()
+
+                if self.fast_dtype:
+                    with torch.autocast(device_type=self.device.type, dtype=self.fast_dtype):
+                        outputs = self.model(**batch)
+                else:
+                    outputs = self.model(**batch)
+                if "loss" not in outputs:
+                    # Generally, the model should return a loss in the outputs via OmniGenBench
+                    # For the Lora models, the loss is computed separately
+                    if hasattr(self.model, "loss_function") and callable(self.model.loss_function):
+                        loss = self.model.loss_function(outputs['logits'], outputs["labels"])
+                    elif (hasattr(self.model, "model")
+                          and hasattr(self.model.model, "loss_function")
+                          and callable(self.model.model.loss_function)):
+                        loss = self.model.model.loss_function(outputs['logits'], outputs["labels"])
+                    else:
+                        raise ValueError(
+                            "The model does not have a loss function defined. "
+                            "Please provide a loss function or ensure the model has one."
+                        )
+                else:
+                    # If the model returns a loss directly
+                    loss = outputs["loss"]
+
+                loss = loss / self.gradient_accumulation_steps
+
+                if self.fast_dtype:
+                    self.scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+
+                if (step + 1) % self.gradient_accumulation_steps == 0 or (
+                    step + 1
+                ) == len(self.train_loader):
+                    if self.fast_dtype:
+                        self.scaler.step(self.optimizer)
+                        self.scaler.update()
+                    else:
+                        self.optimizer.step()
+
+                train_loss.append(loss.item() * self.gradient_accumulation_steps)
+                train_it.set_description(
+                    f"Epoch {epoch + 1}/{self.epochs} Loss: {np.nanmean(train_loss):.4f}"
+                )
+
+            if self.eval_loader is not None and len(self.eval_loader) > 0:
+                valid_metrics = self.evaluate()
+            else:
+                valid_metrics = self.test()
+
+            if self._is_metric_better(valid_metrics, stage="valid"):
+                self._save_state_dict()
+                patience = 0
+            else:
+                patience += 1
+                if patience >= self.patience:
+                    fprint(f"Early stopping at epoch {epoch + 1}.")
+                    break
+
+            if path_to_save:
+                _path_to_save = path_to_save + "_epoch_" + str(epoch + 1)
+
+                if valid_metrics:
+                    for key, value in valid_metrics.items():
+                        _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
+
+                self.save_model(_path_to_save, **kwargs)
+
+        if self.test_loader is not None and len(self.test_loader) > 0:
+            self._load_state_dict()
+            test_metrics = self.test()
+            self._is_metric_better(test_metrics, stage="test")
+
+        if path_to_save:
+            _path_to_save = path_to_save + "_final"
+            if self.metrics["test"]:
+                for key, value in self.metrics["test"][-1].items():
+                    _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
+
+            self.save_model(_path_to_save, **kwargs)
+
+        self._remove_state_dict()
+
+        return self.metrics
+
+    def evaluate(self):
+        """
+        Evaluate the model on the validation set.
+
+        Returns:
+            dict: Evaluation metrics
+        """
+        with torch.no_grad():
+            self.model.eval()
+            val_truth = []
+            val_preds = []
+            it = tqdm(self.eval_loader, desc="Evaluating")
+            for batch in it:
+                batch.to(self.device)
+                labels = batch["labels"]
+                batch.pop("labels")
+                if self.fast_dtype:
+                    with torch.autocast(device_type="cuda", dtype=self.fast_dtype):
+                        predictions = self.model.predict(batch)["predictions"]
+                else:
+                    predictions = self.model.predict(batch)["predictions"]
+                val_truth.append(labels.float().cpu().numpy(force=True))
+                val_preds.append(predictions.float().cpu().numpy(force=True))
+            val_truth = (
+                np.vstack(val_truth) if labels.ndim > 1 else np.hstack(val_truth)
+            )
+            val_preds = (
+                np.vstack(val_preds) if predictions.ndim > 1 else np.hstack(val_preds)
+            )
+            if not np.all(val_truth == -100):
+                valid_metrics = {}
+                for metric_func in self.compute_metrics:
+                    valid_metrics.update(metric_func(val_truth, val_preds))
+
+                fprint(valid_metrics)
+            else:
+                valid_metrics = {
+                    "Validation set labels may be NaN. No metrics calculated.": 0
+                }
+
+            self.predictions.update({"valid": {"pred": val_preds, "true": val_truth}})
+
+            return valid_metrics
+
+    def test(self):
+        """
+        Test the model on the test set.
+
+        Returns:
+            dict: Test metrics and predictions
+        """
+        with torch.no_grad():
+            self.model.eval()
+            preds = []
+            truth = []
+            it = tqdm(self.test_loader, desc="Testing")
+            for batch in it:
+                batch.to(self.device)
+                labels = batch["labels"]
+                batch.pop("labels")
+                if self.fast_dtype:
+                    with torch.autocast(device_type="cuda", dtype=self.fast_dtype):
+                        predictions = self.model.predict(batch)["predictions"]
+                else:
+                    predictions = self.model.predict(batch)["predictions"]
+                truth.append(labels.float().cpu().numpy(force=True))
+                preds.append(predictions.float().cpu().numpy(force=True))
+            truth = np.vstack(truth) if labels.ndim > 1 else np.hstack(truth)
+            preds = np.vstack(preds) if predictions.ndim > 1 else np.hstack(preds)
+            if not np.all(truth == -100):
+                test_metrics = {}
+                for metric_func in self.compute_metrics:
+                    test_metrics.update(metric_func(truth, preds))
+
+                fprint(test_metrics)
+            else:
+                test_metrics = {"Test set labels may be NaN. No metrics calculated.": 0}
+
+            self.predictions.update({"test": {"pred": preds, "true": truth}})
+
+            return test_metrics
+
+    def predict(self, data_loader):
+        """
+        Generate predictions using the model.
+
+        Args:
+            data_loader: DataLoader for prediction data
+
+        Returns:
+            torch.Tensor: Model predictions
+        """
+        return self.model.predict(data_loader)
+
+    def get_model(self, **kwargs):
+        """
+        Get the trained model.
+
+        Args:
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            The trained model
+        """
+        return self.model
+
+    def compute_metrics(self):
+        """
+        Get the metric computation functions.
+
+        Returns:
+            list: List of metric computation functions
+        """
+        return self.compute_metrics
+
+    def unwrap_model(self, model=None):
+        """
+        Unwrap the model from any distributed training wrappers.
+
+        Args:
+            model: Model to unwrap (default: None, uses self.model)
+
+        Returns:
+            The unwrapped model
+        """
+        if model is None:
+            model = self.model
+        try:
+            return self.accelerator.unwrap_model(model)
+        except:
+            try:
+                return model.module
+            except:
+                return model
+
+    def save_model(self, path, overwrite=False, **kwargs):
+        """
+        Save the model to disk.
+
+        Args:
+            path (str): Path to save the model
+            overwrite (bool): Whether to overwrite existing files (default: False)
+            **kwargs: Additional keyword arguments
+        """
+        self.unwrap_model().save(path, overwrite, **kwargs)
+
+    def _load_state_dict(self):
+        """
+        Load model state dictionary from temporary file.
+
+        Returns:
+            dict: Model state dictionary
+        """
+        if os.path.exists(self._model_state_dict_path):
+            self.unwrap_model().load_state_dict(
+                torch.load(self._model_state_dict_path, map_location='cpu')
+            )
+            self.unwrap_model().to(self.device)
+
+    def _save_state_dict(self):
+        """
+        Save model state dictionary to temporary file.
+
+        Returns:
+            str: Path to temporary file
+        """
+        if not hasattr(self, "_model_state_dict_path"):
+            # Create the temporary checkpoint file, then close it so it can be written to later
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
+            self._model_state_dict_path = tmp_file.name
+            tmp_file.close()
+
+        try:
+            if os.path.exists(self._model_state_dict_path):
+                os.remove(self._model_state_dict_path)
+        except Exception as e:
+            fprint(
+                f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
+            )
+
+        torch.save(self.unwrap_model().state_dict(), self._model_state_dict_path)
+
+    def _remove_state_dict(self):
+        """
+        Remove temporary state dictionary file.
+        """
+        if hasattr(self, "_model_state_dict_path"):
+            try:
+                if os.path.exists(self._model_state_dict_path):
+                    os.remove(self._model_state_dict_path)
+            except Exception as e:
+                fprint(
+                    f"Failed to remove the temporary checkpoint file {self._model_state_dict_path}: {e}"
+                )
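The native Trainer above wires the pieces together: it builds DataLoaders from the datasets, runs mixed-precision forward/backward passes with gradient accumulation, evaluates after every epoch, infers the optimization direction from the metric names, early-stops on patience, and checkpoints the best weights to a temporary file. The sketch below shows the call pattern only; the model and datasets are placeholders and are assumed to follow the contract the code expects (forward(**batch) returning a dict with "loss"/"logits", predict(batch) returning {"predictions": ...}, and dict-like batches carrying a "labels" key).

# Call-pattern sketch (placeholders, not from the package docs): `model`,
# `train_set`, `valid_set`, and `test_set` are assumed to satisfy the contract above.
import torch
from sklearn.metrics import f1_score

from omnigenome.src.trainer.trainer import Trainer

def macro_f1(truth, preds):
    # compute_metrics callables receive (y_true, y_pred) arrays and return a dict;
    # the "f1" key lets _infer_optimization_direction pick "larger_is_better".
    return {"f1": f1_score(truth, preds, average="macro")}

trainer = Trainer(
    model=model,                      # placeholder OmniGenome-style model
    train_dataset=train_set,          # placeholder datasets yielding dict-like batches
    eval_dataset=valid_set,
    test_dataset=test_set,
    epochs=10,
    batch_size=16,
    patience=3,                       # stop after 3 epochs without improvement
    gradient_accumulation_steps=2,    # effective batch size 32
    optimizer=torch.optim.AdamW(model.parameters(), lr=2e-5),
    compute_metrics=[macro_f1],
    seed=42,
    autocast="bf16",                  # one of the dtype keys mapped in __init__ above
)

metrics = trainer.train(path_to_save="ckpt/omnigenome_demo")  # placeholder path
print(metrics["best_valid"])          # best validation metrics tracked by _is_metric_better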
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+# File: __init__.py
+# Time: 02:22 20/06/2025
+# Author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# Website: https://yangheng95.github.io
+# GitHub: https://github.com/yangheng95
+# HuggingFace: https://huggingface.co/yangheng
+# Google Scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2025. All rights reserved.
+"""
+This package contains modules for the dataset hub.
+"""
+