omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of omnigenome might be problematic.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
omnigenome/auto/auto_bench/auto_bench.py
@@ -0,0 +1,484 @@
# -*- coding: utf-8 -*-
# file: auto_bench.py
# time: 11:54 14/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import os
import time
import warnings

import findfile
import numpy as np
import torch
from metric_visualizer import MetricVisualizer

from transformers import TrainingArguments, Trainer as HFTrainer
from ...src.abc.abstract_tokenizer import OmniTokenizer
from ...src.lora.lora_model import OmniLoraModel
from ...src.misc.utils import (
    seed_everything,
    fprint,
    load_module_from_path,
    check_bench_version,
    clean_temp_checkpoint,
)
from ...src.trainer.trainer import Trainer
from ...src.trainer.accelerate_trainer import AccelerateTrainer
from ...utility.hub_utils import download_benchmark
from ... import __version__ as omnigenome_version


class AutoBench:
    """
    AutoBench is a class for automatically benchmarking genomic foundation models.

    This class provides a comprehensive framework for evaluating genomic models
    across multiple benchmarks and tasks. It handles loading benchmarks, models,
    tokenizers, and running evaluations with proper metric tracking and result
    visualization.

    AutoBench supports various evaluation scenarios, including:
    - Single-model evaluation across multiple benchmarks
    - Multi-seed evaluation for robustness testing
    - Different trainer backends ('native', 'accelerate', 'hf_trainer')
    - Automatic metric visualization and result tracking

    Attributes:
        benchmark (str): The name or path of the benchmark to use.
        model_name_or_path (str): The name or path of the model to evaluate.
        tokenizer: The tokenizer to use for evaluation.
        autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
        overwrite (bool): Whether to overwrite existing evaluation results.
        trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
        mv_path (str): Path to the metric visualizer file.
        mv (MetricVisualizer): The metric visualizer instance.
        bench_metadata: Metadata about the benchmark configuration.
    """

    def __init__(
        self,
        benchmark,
        model_name_or_path,
        tokenizer=None,
        **kwargs,
    ):
        """
        Initializes the AutoBench instance.

        Args:
            benchmark (str): The name or path of the benchmark to use.
            model_name_or_path (str): The name or path of the model to evaluate.
            tokenizer: The tokenizer to use. If None, it is loaded from the model path.
            **kwargs: Additional keyword arguments.
                - autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
                  Defaults to 'fp16'.
                - overwrite (bool): Whether to overwrite existing evaluation results.
                  Defaults to False.
                - trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
                  Defaults to 'native'.

        Example:
            >>> # Initialize with a benchmark and model
            >>> bench = AutoBench("RGB", "model_name")

            >>> # Initialize with custom settings
            >>> bench = AutoBench("RGB", "model_name",
            ...                   autocast="bf16", trainer="accelerate")
        """
        self.benchmark = benchmark.rstrip("/")
        self.autocast = kwargs.pop("autocast", "fp16")
        self.overwrite = kwargs.pop("overwrite", False)
        self.trainer = kwargs.pop("trainer", "native")

        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        if isinstance(self.model_name_or_path, str):
            self.model_name_or_path = self.model_name_or_path.rstrip("/")
            self.model_name = self.model_name_or_path.split("/")[-1]
        else:
            self.model_name = self.model_name_or_path.__class__.__name__
        if isinstance(tokenizer, str):
            self.tokenizer = tokenizer.rstrip("/")
        os.makedirs("./autobench_evaluations", exist_ok=True)
        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        mv_name = f"{benchmark}-{self.model_name}"
        self.mv_path = f"./autobench_evaluations/{mv_name}-{time_str}.mv"

        # Resume from an existing metric-visualizer record unless overwriting
        mv_paths = findfile.find_files(
            "./autobench_evaluations",
            and_key=[benchmark, self.model_name, ".mv"],
        )
        if mv_paths and not self.overwrite:
            self.mv = MetricVisualizer.load(mv_paths[-1])
            self.mv.summary(round=4)
        else:
            self.mv = MetricVisualizer(self.mv_path)
        if not os.path.exists(self.benchmark):
            fprint(
                "Benchmark:",
                benchmark,
                "does not exist locally. Searching online for available benchmarks.",
            )
            self.benchmark = download_benchmark(self.benchmark)

        # Import the benchmark list from the benchmark's metadata module
        self.bench_metadata = load_module_from_path(
            "bench_metadata", f"{self.benchmark}/metadata.py"
        )
        check_bench_version(
            self.bench_metadata.__omnigenome_version__, omnigenome_version
        )
        fprint("Loaded benchmarks: ", self.bench_metadata.bench_list)
        self.bench_info()

    def bench_info(self):
        """
        Prints and returns information about the current benchmark setup.

        This method provides a comprehensive overview of the current
        benchmark configuration, including benchmark details, model information,
        and evaluation settings.

        Returns:
            str: A string containing benchmark information.

        Example:
            >>> info = bench.bench_info()
            >>> print(info)
        """
        info = f"Benchmark Root: {self.benchmark}\n"
        info += f"Benchmark List: {self.bench_metadata.bench_list}\n"
        info += f"Model Name or Path: {self.model_name}\n"
        info += f"Tokenizer: {self.tokenizer}\n"
        info += f"Metric Visualizer Path: {self.mv_path}\n"
        info += f"BenchConfig Details: {self.bench_metadata}\n"
        fprint(info)
        return info

    def run(self, **kwargs):
        """
        Runs the benchmarking process.

        This method iterates through the tasks in the benchmark, loads the
        corresponding configurations, initializes the model, tokenizer, and
        datasets, and then trains and evaluates the model. It supports multiple
        evaluation seeds and various trainer backends.

        Args:
            **kwargs: Additional keyword arguments that override the default
                parameters in the benchmark configuration.

        Example:
            >>> # Run benchmarking with default settings
            >>> bench.run()

            >>> # Run with custom parameters
            >>> bench.run(learning_rate=1e-4, batch_size=16)
        """
        bs_scale = kwargs.pop("bs_scale", 1)
        # Iterate over the tasks listed in the benchmark metadata
        for i, bench in enumerate(self.bench_metadata.bench_list):
            clean_temp_checkpoint(1)  # clean temp checkpoints older than 1 day
            fprint(
                ">" * 80,
                f"\nRunning evaluation for task: {bench}",
                "Progress: ",
                i + 1,
                "/",
                len(self.bench_metadata.bench_list),
                f"{(i + 1) * 100 / len(self.bench_metadata.bench_list)}%",
            )
            bench_config_path = findfile.find_file(
                self.benchmark, and_key=f"{self.benchmark}.{bench}.config".split(".")
            )
            config = load_module_from_path("config", bench_config_path)
            bench_config = config.bench_config
            fprint(f"Loaded config for {bench} from {bench_config_path}")
            fprint(bench_config)

            # Init tokenizer: reuse the provided one or load it from the model
            if not self.tokenizer:
                tokenizer = OmniTokenizer.from_pretrained(
                    self.model_name_or_path,
                    trust_remote_code=bench_config.get("trust_remote_code", True),
                    **bench_config,
                )
            else:
                tokenizer = self.tokenizer

            if not isinstance(bench_config["seeds"], list):
                bench_config["seeds"] = [bench_config["seeds"]]

            random_seeds = bench_config["seeds"]
            for seed in random_seeds:
                # Override bench_config entries with the kwargs passed to run()
                _kwargs = kwargs.copy()
                for key, value in _kwargs.items():
                    if key in bench_config:
                        fprint(
                            "Override", key, "with", value, "according to the input kwargs"
                        )
                    else:
                        warnings.warn(
                            f"kwarg {key} not found in bench_config; setting {key} = {value}"
                        )
                    bench_config.update({key: value})

                # Drop kwargs that are now covered by bench_config
                for key in list(bench_config.keys()):
                    if key in _kwargs:
                        _kwargs.pop(key)

                fprint(
                    f"AutoBench Config for {bench}:",
                    "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
                )

                batch_size = bench_config.get("batch_size", 8) * bs_scale

                record_name = f"{self.benchmark}-{bench}-{self.model_name}".split("/")[
                    -1
                ]
                # Skip this task if results for all seeds are already recorded
                if record_name in self.mv.transpose() and len(
                    list(self.mv.transpose()[record_name].values())[0]
                ) >= len(bench_config["seeds"]):
                    continue

                seed_everything(seed)
                if self.model_name_or_path:
                    model_cls = bench_config["model_cls"]
                    model = model_cls(
                        self.model_name_or_path,
                        tokenizer=tokenizer,
                        label2id=bench_config["label2id"],
                        num_labels=bench_config["num_labels"],
                        trust_remote_code=True,
                        ignore_mismatched_sizes=True,
                    )
                else:
                    raise ValueError(
                        "model_name_or_path is not specified. Please provide a valid model name or path."
                    )

                fprint(f"\n{model}")

                if kwargs.get("lora_config", None) is not None:
                    fprint("Applying LoRA to the model with config:", kwargs["lora_config"])
                    model = OmniLoraModel(model, **kwargs.get("lora_config", {}))

                # Init datasets
                dataset_cls = bench_config["dataset_cls"]

                # Cap max_length at the model's positional capacity, if known
                if hasattr(model.config, "max_position_embeddings"):
                    max_length = min(
                        bench_config["max_length"],
                        model.config.max_position_embeddings,
                    )
                else:
                    max_length = bench_config["max_length"]

                train_set = dataset_cls(
                    data_source=bench_config["train_file"],
                    tokenizer=tokenizer,
                    label2id=bench_config["label2id"],
                    max_length=max_length,
                    structure_in=bench_config.get("structure_in", False),
                    max_examples=bench_config.get("max_examples", None),
                    shuffle=bench_config.get("shuffle", True),
                    drop_long_seq=bench_config.get("drop_long_seq", False),
                    **_kwargs,
                )
                test_set = dataset_cls(
                    data_source=bench_config["test_file"],
                    tokenizer=tokenizer,
                    label2id=bench_config["label2id"],
                    max_length=max_length,
                    structure_in=bench_config.get("structure_in", False),
                    max_examples=bench_config.get("max_examples", None),
                    shuffle=False,
                    drop_long_seq=bench_config.get("drop_long_seq", False),
                    **_kwargs,
                )
                valid_set = dataset_cls(
                    data_source=bench_config["valid_file"],
                    tokenizer=tokenizer,
                    label2id=bench_config["label2id"],
                    max_length=max_length,
                    structure_in=bench_config.get("structure_in", False),
                    max_examples=bench_config.get("max_examples", None),
                    shuffle=False,
                    drop_long_seq=bench_config.get("drop_long_seq", False),
                    **_kwargs,
                )

                if self.trainer == "hf_trainer":
                    # Set up the HuggingFace Trainer; forward only kwargs that
                    # are valid TrainingArguments fields
                    hf_kwargs = {
                        k: v
                        for k, v in kwargs.items()
                        if hasattr(TrainingArguments, k) and k != "output_dir"
                    }
                    training_args = TrainingArguments(
                        output_dir=f"./autobench_evaluations/{self.model_name}-{bench}",
                        num_train_epochs=hf_kwargs.pop(
                            "num_train_epochs", bench_config["epochs"]
                        ),
                        per_device_train_batch_size=hf_kwargs.pop(
                            "batch_size", batch_size
                        ),
                        per_device_eval_batch_size=hf_kwargs.pop(
                            "batch_size", batch_size
                        ),
                        gradient_accumulation_steps=hf_kwargs.pop(
                            "gradient_accumulation_steps", 1
                        ),
                        learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
                        weight_decay=hf_kwargs.pop("weight_decay", 0),
                        eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
                        save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
                        fp16=hf_kwargs.pop("fp16", True),
                        remove_unused_columns=False,
                        label_names=["labels"],
                        **hf_kwargs,
                    )

                    valid_set = valid_set if len(valid_set) else test_set

                    if len(bench_config["compute_metrics"]) > 1:
                        fprint(
                            "Multiple metrics are not supported by HFTrainer; using the first metric only."
                        )
                    trainer = HFTrainer(
                        model=model,
                        args=training_args,
                        train_dataset=train_set,
                        eval_dataset=valid_set,
                        compute_metrics=(
                            bench_config["compute_metrics"][0]
                            if isinstance(bench_config["compute_metrics"], list)
                            else bench_config["compute_metrics"]
                        ),
                    )

                    # Evaluate once before training, then train and evaluate
                    eval_result = trainer.evaluate(
                        valid_set if len(valid_set) else test_set
                    )
                    fprint(eval_result)
                    train_result = trainer.train()
                    eval_result = trainer.evaluate()
                    test_result = trainer.evaluate(
                        test_set if len(test_set) else valid_set
                    )

                    metrics = {
                        "train": train_result.metrics,
                        "eval": eval_result,
                        "test": test_result,
                    }
                    fprint(metrics)
                else:
                    optimizer = torch.optim.AdamW(
                        filter(lambda p: p.requires_grad, model.parameters()),
                        lr=bench_config.get("learning_rate", 2e-5),
                        weight_decay=bench_config.get("weight_decay", 0),
                    )
                    if self.trainer == "accelerate":
                        trainer_cls = AccelerateTrainer
                    else:
                        trainer_cls = Trainer
                    fprint(f"Using Trainer: {trainer_cls}")
                    trainer = trainer_cls(
                        model=model,
                        train_dataset=train_set,
                        eval_dataset=valid_set,
                        test_dataset=test_set,
                        batch_size=batch_size,
                        patience=bench_config.get("patience", 3),
                        epochs=bench_config["epochs"],
                        gradient_accumulation_steps=bench_config.get(
                            "gradient_accumulation_steps", 1
                        ),
                        optimizer=optimizer,
                        loss_fn=bench_config.get("loss_fn", None),
                        compute_metrics=bench_config["compute_metrics"],
                        seed=seed,
                        autocast=self.autocast,
                        **_kwargs,
                    )
                    metrics = trainer.train()

                    predictions = trainer.predictions

                    if bench_config.get("save_predictions", False):
                        os.makedirs(f"predictions/{bench}", exist_ok=True)
                        for split in predictions.keys():
                            with open(
                                f"predictions/{bench}/{split}.npy",
                                "wb",
                            ) as f:
                                np.save(f, predictions[split])

                    if metrics:
                        # Log the final test metrics for this seed
                        for key, value in metrics["test"][-1].items():
                            try:
                                value = float(value)
                            except (TypeError, ValueError):
                                pass  # ignore non-float values
                            self.mv.log(record_name, key, value)

                self.mv.summary(round=4)
                self.mv.dump(self.mv_path)
                self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))
                # The hf_trainer branch does not create a local optimizer
                del model, trainer
                if self.trainer != "hf_trainer":
                    del optimizer
                torch.cuda.empty_cache()