omnigenome-0.3.0a0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome has been flagged as potentially problematic.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
omnigenome/auto/auto_train/auto_train.py
@@ -0,0 +1,430 @@
# -*- coding: utf-8 -*-
# file: auto_train.py
# time: 11:54 14/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import os
import time
import warnings

import findfile
import torch
from metric_visualizer import MetricVisualizer
from transformers import TrainingArguments, Trainer as HFTrainer

from ...src.lora.lora_model import OmniLoraModel
from ...src.abc.abstract_tokenizer import OmniTokenizer
from ...src.misc.utils import (
    seed_everything,
    fprint,
    load_module_from_path,
    clean_temp_checkpoint,
)
from ...src.trainer.accelerate_trainer import AccelerateTrainer
from ...src.trainer.trainer import Trainer

autotrain_evaluations = "./autotrain_evaluations"


class AutoTrain:
    """
    AutoTrain is a class for automatically training genomic foundation models on a given dataset.

    This class provides a comprehensive framework for training genomic models
    on various datasets with minimal configuration. It handles dataset loading,
    model initialization, training configuration, and result tracking.

    AutoTrain supports various training scenarios including:
    - Single dataset training with multiple seeds
    - Different trainer backends (native, accelerate, huggingface)
    - Automatic metric visualization and result tracking
    - Configurable training parameters

    Attributes:
        dataset (str): The name or path of the dataset to use for training.
        model_name_or_path (str): The name or path of the model to train.
        tokenizer: The tokenizer to use for training.
        autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
        overwrite (bool): Whether to overwrite existing training results.
        trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
        mv_path (str): Path to the metric visualizer file.
        mv (MetricVisualizer): The metric visualizer instance.
    """

    def __init__(
        self,
        dataset,
        model_name_or_path,
        tokenizer=None,
        **kwargs,
    ):
        """
        Initialize the AutoTrain instance.

        Args:
            dataset (str): The name or path of the dataset to use for training.
            model_name_or_path (str): The name or path of the model to train.
            tokenizer: The tokenizer to use. If None, it will be loaded from the model path.
            **kwargs: Additional keyword arguments.
                - autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
                  Defaults to 'fp16'.
                - overwrite (bool): Whether to overwrite existing training results.
                  Defaults to False.
                - trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
                  Defaults to 'accelerate'.

        Example:
            >>> # Initialize with a dataset and model
            >>> trainer = AutoTrain("dataset_name", "model_name")

            >>> # Initialize with custom settings
            >>> trainer = AutoTrain("dataset_name", "model_name",
            ...                     autocast="bf16", trainer="accelerate")
        """
        self.dataset = dataset.rstrip("/")
        self.autocast = kwargs.pop("autocast", "fp16")
        self.overwrite = kwargs.pop("overwrite", False)
        self.trainer = kwargs.pop("trainer", "accelerate")

        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        if isinstance(self.model_name_or_path, str):
            self.model_name_or_path = self.model_name_or_path.rstrip("/")
            self.model_name = self.model_name_or_path.split("/")[-1]
        else:
            self.model_name = self.model_name_or_path.__class__.__name__
        if isinstance(tokenizer, str):
            self.tokenizer = tokenizer.rstrip("/")
        os.makedirs(autotrain_evaluations, exist_ok=True)
        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        mv_name = f"{dataset}-{self.model_name}"
        self.mv_path = f"{autotrain_evaluations}/{mv_name}-{time_str}.mv"

        mv_paths = findfile.find_files(
            autotrain_evaluations,
            [dataset, self.model_name, ".mv"],
        )
        if mv_paths and not self.overwrite:
            self.mv = MetricVisualizer.load(mv_paths[-1])
            self.mv.summary(round=4)
        else:
            self.mv = MetricVisualizer(self.mv_path)
        self.bench_info()

    def bench_info(self):
        """
        Print and return information about the current training setup.

        This method provides a comprehensive overview of the current
        training configuration, including dataset details, model information,
        and training settings.

        Returns:
            str: A string containing training setup information.

        Example:
            >>> info = trainer.bench_info()
            >>> print(info)
        """
        info = f"Dataset Root: {self.dataset}\n"
        info += f"Model Name or Path: {self.model_name}\n"
        info += f"Tokenizer: {self.tokenizer}\n"
        info += f"Metric Visualizer Path: {self.mv_path}\n"
        fprint(info)
        return info

    def run(self, **kwargs):
        """
        Run the training process.

        This method loads the dataset configuration, initializes the model and
        tokenizer, and runs training across multiple seeds. It supports various
        training backends and automatic result tracking.

        Args:
            **kwargs: Additional keyword arguments that will override the default
                parameters in the dataset configuration.

        Example:
            >>> # Run training with default settings
            >>> trainer.run()

            >>> # Run with custom parameters
            >>> trainer.run(learning_rate=1e-4, batch_size=16)
        """
        clean_temp_checkpoint(1)  # clean temp checkpoints older than 1 day

        _kwargs = kwargs.copy()
        bench_config_path = findfile.find_file(
            self.dataset, f"{self.dataset}.config".split(".")
        )
        config = load_module_from_path("config", bench_config_path)
        bench_config = config.bench_config
        fprint(f"Loaded config for {self.dataset} from {bench_config_path}")
        fprint(bench_config)

        # Init Tokenizer and Model
        if not self.tokenizer:
            tokenizer = OmniTokenizer.from_pretrained(
                self.model_name_or_path, trust_remote_code=True
            )
        else:
            tokenizer = self.tokenizer

        if not isinstance(bench_config["seeds"], list):
            bench_config["seeds"] = [bench_config["seeds"]]

        random_seeds = bench_config["seeds"]
        for seed in random_seeds:
            for key, value in _kwargs.items():
                if key in bench_config:
                    fprint(
                        "Override", key, "with", value, "according to the input kwargs"
                    )
                    bench_config.update({key: value})
                else:
                    warnings.warn(
                        f"kwarg: {key} not found in bench_config while setting {key} = {value}"
                    )
                    bench_config.update({key: value})

            for key, value in bench_config.items():
                if key in bench_config and key in _kwargs:
                    _kwargs.pop(key)

            fprint(
                f"AutoBench Config for {self.dataset}:",
                "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
            )
            for key, value in _kwargs.items():
                if key in bench_config:
                    fprint(
                        "Override", key, "with", value, "according to the input kwargs"
                    )
                    bench_config.update({key: value})
                else:
                    warnings.warn(
                        f"kwarg: {key} not found in bench_config while setting {key} = {value}"
                    )
                    bench_config.update({key: value})

            for key, value in bench_config.items():
                if key in bench_config and key in _kwargs:
                    _kwargs.pop(key)

            fprint(
                f"AutoBench Config for {self.dataset}:",
                "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
            )

            batch_size = (
                bench_config["batch_size"] if "batch_size" in bench_config else 8
            )

            record_name = f"{os.path.basename(self.dataset)}-{self.model_name}".split(
                "/"
            )[-1]
            # check if the record exists
            if record_name in self.mv.transpose() and len(
                list(self.mv.transpose()[record_name].values())[0]
            ) >= len(bench_config["seeds"]):
                continue

            seed_everything(seed)
            if self.model_name_or_path:
                model_cls = bench_config["model_cls"]
                model = model_cls(
                    self.model_name_or_path,
                    tokenizer=tokenizer,
                    label2id=bench_config.label2id,
                    num_labels=bench_config["num_labels"],
                    trust_remote_code=True,
                    ignore_mismatched_sizes=True,
                )
            else:
                raise ValueError(
                    "model_name_or_path is not specified. Please provide a valid model name or path."
                )

            if kwargs.get("lora_config", None) is not None:
                fprint("Applying LoRA to the model with config:", kwargs["lora_config"])
                model = OmniLoraModel(model, **kwargs.get("lora_config", {}))

            # Init Trainer
            dataset_cls = bench_config["dataset_cls"]

            if hasattr(model.config, "max_position_embeddings"):
                max_length = min(
                    bench_config["max_length"],
                    model.config.max_position_embeddings,
                )
            else:
                max_length = bench_config["max_length"]

            train_set = dataset_cls(
                data_source=bench_config["train_file"],
                tokenizer=tokenizer,
                label2id=bench_config["label2id"],
                max_length=max_length,
                structure_in=bench_config.get("structure_in", False),
                max_examples=bench_config.get("max_examples", None),
                shuffle=bench_config.get("shuffle", True),
                drop_long_seq=bench_config.get("drop_long_seq", False),
                **_kwargs,
            )
            test_set = dataset_cls(
                data_source=bench_config["test_file"],
                tokenizer=tokenizer,
                label2id=bench_config["label2id"],
                max_length=max_length,
                structure_in=bench_config.get("structure_in", False),
                max_examples=bench_config.get("max_examples", None),
                shuffle=False,
                drop_long_seq=bench_config.get("drop_long_seq", False),
                **_kwargs,
            )
            valid_set = dataset_cls(
                data_source=bench_config["valid_file"],
                tokenizer=tokenizer,
                label2id=bench_config["label2id"],
                max_length=max_length,
                structure_in=bench_config.get("structure_in", False),
                max_examples=bench_config.get("max_examples", None),
                shuffle=False,
                drop_long_seq=bench_config.get("drop_long_seq", False),
                **_kwargs,
            )

            if self.trainer == "hf_trainer":
                # Set up HuggingFace Trainer
                hf_kwargs = {
                    k: v
                    for k, v in kwargs.items()
                    if hasattr(TrainingArguments, k) and k != "output_dir"
                }
                training_args = TrainingArguments(
                    output_dir=f"./autotrain_evaluations/{self.model_name}",
                    num_train_epochs=hf_kwargs.pop(
                        "num_train_epochs", bench_config["epochs"]
                    ),
                    per_device_train_batch_size=hf_kwargs.pop("batch_size", batch_size),
                    per_device_eval_batch_size=hf_kwargs.pop("batch_size", batch_size),
                    gradient_accumulation_steps=hf_kwargs.pop(
                        "gradient_accumulation_steps", 1
                    ),
                    learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
                    weight_decay=hf_kwargs.pop("weight_decay", 0),
                    eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
                    save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
                    fp16=hf_kwargs.pop("fp16", True),
                    remove_unused_columns=False,
                    label_names=["labels"],
                    **hf_kwargs,
                )

                valid_set = valid_set if len(valid_set) else test_set

                if len(bench_config["compute_metrics"]) > 1:
                    fprint(
                        "Multiple metrics are not supported by HFTrainer; using the first metric only."
                    )
                trainer = HFTrainer(
                    model=model,
                    args=training_args,
                    train_dataset=train_set,
                    eval_dataset=valid_set,
                    compute_metrics=(
                        bench_config["compute_metrics"][0]
                        if isinstance(bench_config["compute_metrics"], list)
                        else bench_config["compute_metrics"]
                    ),
                )

                # Train and evaluate
                eval_result = trainer.evaluate(
                    valid_set if len(valid_set) else test_set
                )
                print(eval_result)
                train_result = trainer.train()
                eval_result = trainer.evaluate()
                test_result = trainer.evaluate(test_set if len(test_set) else valid_set)

                metrics = {
                    "train": train_result.metrics,
                    "eval": eval_result,
                    "test": test_result,
                }
                fprint(metrics)
            else:
                optimizer = torch.optim.AdamW(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=(
                        bench_config["learning_rate"]
                        if "learning_rate" in bench_config
                        else 2e-5
                    ),
                    weight_decay=(
                        bench_config["weight_decay"]
                        if "weight_decay" in bench_config
                        else 0
                    ),
                )
                if self.trainer == "accelerate":
                    trainer_cls = AccelerateTrainer
                else:
                    trainer_cls = Trainer
                fprint(f"Using Trainer: {trainer_cls}")
                trainer = trainer_cls(
                    model=model,
                    train_dataset=train_set,
                    eval_dataset=valid_set,
                    test_dataset=test_set,
                    batch_size=batch_size,
                    patience=(
                        bench_config["patience"] if "patience" in bench_config else 3
                    ),
                    epochs=bench_config["epochs"],
                    gradient_accumulation_steps=bench_config.get(
                        "gradient_accumulation_steps", 1
                    ),
                    optimizer=optimizer,
                    loss_fn=(
                        bench_config["loss_fn"] if "loss_fn" in bench_config else None
                    ),
                    compute_metrics=bench_config["compute_metrics"],
                    seed=seed,
                    autocast=self.autocast,
                    **_kwargs,
                )
                metrics = trainer.train()
                save_path = os.path.join(
                    autotrain_evaluations, self.dataset, self.model_name
                )
                trainer.save_model(save_path)

                if metrics:
                    for key, value in metrics["test"][-1].items():
                        try:
                            value = float(value)
                        except:
                            pass  # ignore non-float values
                        self.mv.log(f"{record_name}", f"{key}", value)
                # for key, value in metrics['test'][-1].items():
                #     self.mv.log(f'{record_name}', f'test_{key}', value)
                # for i, valid_metrics in enumerate(metrics["valid"]):
                #     for key, value in valid_metrics.items():
                #         self.mv.log(f'{record_name}', f'valid_epoch_{i}_{key}', value)

                self.mv.summary(round=4)
                self.mv.dump(self.mv_path)
                self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))
                del model, trainer, optimizer
                torch.cuda.empty_cache()
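A minimal usage sketch assembled from the docstrings above (not part of the packaged file): the dataset directory and model identifier are hypothetical placeholders, and the dataset directory is assumed to contain the config module that run() locates via findfile.

from omnigenome import AutoTrain

# Hypothetical paths; a Hugging Face model ID or a local checkpoint works the same way.
trainer = AutoTrain(
    dataset="path/to/dataset_dir",       # assumed to contain the dataset's config module
    model_name_or_path="path/to/model",
    autocast="fp16",                     # or "bf16"
    trainer="accelerate",                # "native" | "accelerate" | "hf_trainer"
    overwrite=False,
)
# Keyword arguments passed to run() override matching keys in the dataset's bench_config.
trainer.run(learning_rate=2e-5, batch_size=8)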
omnigenome/auto/auto_train/auto_train_cli.py
@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
# file: auto_train_cli.py
# time: 19:18 05/02/2025
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# Homepage: https://yangheng95.github.io
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.

import argparse
import os
import platform
import sys
import time

from typing import Optional
from omnigenome import AutoTrain
from omnigenome.src.misc.utils import fprint


def train_command(args: Optional[list] = None):
    """
    Entry point for the OmniGenome auto-train command-line interface.

    :param args: A list of command-line arguments. If None, `sys.argv` is used.
    """

    parser = create_parser()
    parsed_args = parser.parse_args(args)

    model_path = parsed_args.model
    fprint(f"\n>> Starting evaluation for model: {model_path}")

    # Special handling for multimolecule models
    if "multimolecule" in model_path:
        from multimolecule import RnaTokenizer, AutoModelForTokenPrediction

        tokenizer = RnaTokenizer.from_pretrained(model_path)
        model = AutoModelForTokenPrediction.from_pretrained(
            model_path, trust_remote_code=True
        ).base_model
    else:
        tokenizer = parsed_args.tokenizer
        model = model_path

    # Initialize AutoTraining
    autobench = AutoTrain(
        dataset=parsed_args.dataset,
        model_name_or_path=model,
        tokenizer=tokenizer,
        overwrite=parsed_args.overwrite,
        trainer=parsed_args.trainer,
    )

    # Run evaluation
    autobench.run(**vars(parsed_args))


def create_parser() -> argparse.ArgumentParser:
    """
    Creates the argument parser for the auto-train CLI.

    :return: An `argparse.ArgumentParser` instance.
    """
    parser = argparse.ArgumentParser(
        description="Genomic Foundation Model Benchmark Suite (Single Model)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required argument
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        help="Path to the dataset and training configuration file.",
    )
    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
        help="Path to the tokenizer to use (HF tokenizer ID or local path).",
    )

    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="Path to the model to evaluate (HF model ID or local path).",
    )

    # Optional arguments
    parser.add_argument(
        "--overwrite",
        type=bool,
        default=False,
        help="Overwrite existing bench results, otherwise resume from training checkpoint.",
    )
    parser.add_argument(
        "--bs_scale",
        type=int,
        default=1,
        help="Batch size scale factor. To increase GPU memory utilization, set to 2 or 4, etc.",
    )
    parser.add_argument(
        "--trainer",
        type=str,
        default="accelerate",
        choices=["native", "accelerate", "hf_trainer"],
        help="Trainer to use for training. \n"
        "Use 'accelerate' for distributed training. "
        "You can use 'accelerate config' to customize its behavior.\n"
        "Use 'hf_trainer' for the Hugging Face Trainer. \n"
        "Set to 'native' to use the native PyTorch training loop.\n",
    )
    parser.add_argument(
        "--autocast",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "bf16", "fp8", "no"],
        help="Automatic mixed precision training mode.",
    )
    return parser


def run_train():
    """
    Wrapper function to run the auto-train command.
    """

    fprint("Running AutoTraining, this may take a while, please be patient...")
    fprint("You can find the logs in the 'autobench_logs' directory.")
    fprint("You can find the metrics in the 'autobench_evaluations' directory.")
    fprint(
        "If you don't intend to use accelerate, please add '--trainer native' to the command."
    )
    fprint(
        "If you want to alter accelerate's behavior, please refer to the 'accelerate config' command."
    )
    fprint("If you encounter any issues, please report them on the GitHub repository.")
    os.makedirs("autobench_logs", exist_ok=True)
    time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
    log_file = f"autobench_logs/AutoBench-{time_str}.log"
    from pathlib import Path

    try:
        mixed_precision = sys.argv[sys.argv.index("--autocast") + 1].lower()
    except ValueError:
        mixed_precision = "fp16"
    file_path = Path(__file__).resolve()
    if (
        "--trainer" in sys.argv
        and sys.argv[sys.argv.index("--trainer") + 1].lower() == "native"
    ):
        cmd_base = f'python "{file_path}" ' + " ".join(sys.argv[1:])
    else:
        cmd_base = (
            f'accelerate launch --mixed_precision "{mixed_precision}" "{file_path}" '
            + " ".join(sys.argv[1:])
        )

    # Use platform-specific tee commands:
    if platform.system() == "Windows":
        # On Windows, use PowerShell's tee-object.
        # The command below launches PowerShell and passes the tee-object command.
        # try:
        #     cmd = f"{cmd_base} 2>&1 | powershell -Command Get-Content {log_file} -Wait"
        # except Exception as e:
        #     fprint(f"The log file cannot be saved due to Error: {e}")
        #     fprint(
        #         "If commands not allowed in PowerShell, "
        #         "please run 'Set-ExecutionPolicy RemoteSigned' in PowerShell with Admin."
        #     )
        cmd = f"{cmd_base} 2>&1"
    else:
        # On Unix-like systems, use the standard tee command.
        cmd = f"{cmd_base} 2>&1 | tee '{log_file}'"

    # Execute the command.
    sys.exit(os.system(cmd))

    # # Regex matching tqdm progress-bar lines (adjust to the actual output)
    # tqdm_pattern = re.compile(r'^.*\d+%\|.*\|\s+\d+/\d+\s+\[.*\]\s*$')
    #
    # last_tqdm_line = ''
    #
    # with open(log_file, 'w', encoding='utf-8') as log_file:
    #     # Run the command and capture its output stream
    #     proc = subprocess.Popen(
    #         cmd_base,
    #         shell=True,
    #         stdout=subprocess.PIPE,
    #         stderr=subprocess.STDOUT,
    #         bufsize=1,
    #         universal_newlines=True
    #     )
    #
    #     # Process the output stream in real time
    #     for line in proc.stdout:
    #         line = line.rstrip()  # strip the trailing newline
    #         if tqdm_pattern.match(line):
    #             # Keep the latest tqdm line
    #             last_tqdm_line = line + '\n'  # the newline must be added manually
    #             # Show the progress bar in place (overwrite mode)
    #             sys.stdout.write('\r' + line)
    #             sys.stdout.flush()
    #         else:
    #             # Write to the log and print normally
    #             log_file.write(line + '\n')
    #             print(line)
    #
    #     # After the command finishes, write the last tqdm progress line
    #     if last_tqdm_line:
    #         log_file.write(last_tqdm_line)
    #         sys.stdout.write('\n')  # final newline to avoid overwriting
    #
    #     sys.exit(proc.returncode)


if __name__ == "__main__":
    train_command()
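For reference, a sketch of driving this CLI from Python instead of the installed console script (whose name lives in entry_points.txt and is not shown in this diff); train_command accepts an explicit argument list, and the model and dataset values below are hypothetical placeholders.

from omnigenome.auto.auto_train.auto_train_cli import train_command

# Hypothetical values; replace with a real model ID/path and dataset directory.
train_command([
    "--model", "path/to/model",
    "--dataset", "path/to/dataset_dir",
    "--trainer", "accelerate",
    "--autocast", "fp16",
])

Calling train_command directly skips the run_train wrapper, so the process is not relaunched under accelerate; use run_train (or the console script) when a distributed launch is wanted.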
omnigenome/auto/bench_hub/__init__.py
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# file: __init__.py
# time: 18:28 11/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.
"""
This package contains modules for the benchmark hub.
"""
omnigenome/auto/bench_hub/bench_hub.py
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# file: bench_hub.py
# time: 11:53 14/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.


class BenchHub:
    """
    A hub for accessing and managing benchmarks.

    This class is intended to provide a centralized way to list, download,
    and inspect available benchmarks for OmniGenome.
    """

    def __init__(self):
        """Initializes the BenchHub."""
        pass

    def run(self):
        """Placeholder for running functionality related to the benchmark hub."""
        pass